Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 194 additions & 47 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Message, Messages
from ..types.content import ContentBlock, Messages
from ..types.exceptions import (
ContextWindowOverflowException,
ModelThrottledException,
)
from ..types.streaming import CitationsDelta, StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys
from .model import Model

Expand Down Expand Up @@ -185,17 +185,6 @@ def get_config(self) -> BedrockConfig:
"""
return self.config

def _should_include_tool_result_status(self) -> bool:
"""Determine whether to include tool result status based on current config."""
include_status = self.config.get("include_tool_result_status", "auto")

if include_status is True:
return True
elif include_status is False:
return False
else: # "auto"
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)

def format_request(
self,
messages: Messages,
Expand Down Expand Up @@ -281,14 +270,12 @@ def format_request(
),
}

def _format_bedrock_messages(self, messages: Messages) -> Messages:
def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
"""Format messages for Bedrock API compatibility.

This function ensures messages conform to Bedrock's expected format by:
- Filtering out SDK_UNKNOWN_MEMBER content blocks
- Cleaning tool result content blocks by removing additional fields that may be
useful for retaining information in hooks but would cause Bedrock validation
exceptions when presented with unexpected fields
- Eagerly filtering content blocks to only include Bedrock-supported fields
- Ensuring all message content blocks are properly formatted for the Bedrock API

Args:
Expand All @@ -298,17 +285,19 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
Messages formatted for Bedrock API compatibility

Note:
Bedrock will throw validation exceptions when presented with additional
unexpected fields in tool result blocks.
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
Unlike other APIs that ignore unknown fields, Bedrock only accepts a strict
subset of fields for each content block type and throws validation exceptions
when presented with unexpected fields. Therefore, we must eagerly filter all
content blocks to remove any additional fields before sending to Bedrock.
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
"""
cleaned_messages = []
cleaned_messages: list[dict[str, Any]] = []

filtered_unknown_members = False
dropped_deepseek_reasoning_content = False

for message in messages:
cleaned_content: list[ContentBlock] = []
cleaned_content: list[dict[str, Any]] = []

for content_block in message["content"]:
# Filter out SDK_UNKNOWN_MEMBER content blocks
Expand All @@ -322,33 +311,13 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
dropped_deepseek_reasoning_content = True
continue

if "toolResult" in content_block:
# Create a new content block with only the cleaned toolResult
tool_result: ToolResult = content_block["toolResult"]
# Format content blocks for Bedrock API compatibility
formatted_content = self._format_request_message_content(content_block)
cleaned_content.append(formatted_content)

if self._should_include_tool_result_status():
# Include status field
cleaned_tool_result = ToolResult(
content=tool_result["content"],
toolUseId=tool_result["toolUseId"],
status=tool_result["status"],
)
else:
# Remove status field
cleaned_tool_result = ToolResult( # type: ignore[typeddict-item]
toolUseId=tool_result["toolUseId"], content=tool_result["content"]
)

cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
cleaned_content.append(cleaned_block)
else:
# Keep other content blocks as-is
cleaned_content.append(content_block)

# Create new message with cleaned content (skip if empty for DeepSeek)
# Create new message with cleaned content (skip if empty)
if cleaned_content:
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
cleaned_messages.append(cleaned_message)
cleaned_messages.append({"content": cleaned_content, "role": message["role"]})

if filtered_unknown_members:
logger.warning(
Expand All @@ -361,6 +330,184 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:

return cleaned_messages

def _should_include_tool_result_status(self) -> bool:
"""Determine whether to include tool result status based on current config."""
include_status = self.config.get("include_tool_result_status", "auto")

if include_status is True:
return True
elif include_status is False:
return False
else: # "auto"
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)

def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
"""Format a Bedrock content block.

Bedrock strictly validates content blocks and throws exceptions for unknown fields.
This function extracts only the fields that Bedrock supports for each content type.

Args:
content: Content block to format.

Returns:
Bedrock formatted content block.

Raises:
TypeError: If the content block type is not supported by Bedrock.
"""
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
if "cachePoint" in content:
return {"cachePoint": {"type": content["cachePoint"]["type"]}}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
if "document" in content:
document = content["document"]
result: dict[str, Any] = {}

# Handle required fields (all optional due to total=False)
if "name" in document:
result["name"] = document["name"]
if "format" in document:
result["format"] = document["format"]

# Handle source
if "source" in document:
result["source"] = {"bytes": document["source"]["bytes"]}

# Handle optional fields
if "citations" in document and document["citations"] is not None:
result["citations"] = {"enabled": document["citations"]["enabled"]}
if "context" in document:
result["context"] = document["context"]

return {"document": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConverseContentBlock.html
if "guardContent" in content:
guard = content["guardContent"]
guard_text = guard["text"]
result = {"text": {"text": guard_text["text"], "qualifiers": guard_text["qualifiers"]}}
return {"guardContent": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html
if "image" in content:
image = content["image"]
source = image["source"]
formatted_source = {}
if "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": image["format"], "source": formatted_source}
return {"image": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
if "reasoningContent" in content:
reasoning = content["reasoningContent"]
result = {}

if "reasoningText" in reasoning:
reasoning_text = reasoning["reasoningText"]
result["reasoningText"] = {}
if "text" in reasoning_text:
result["reasoningText"]["text"] = reasoning_text["text"]
# Only include signature if truthy (avoid empty strings)
if reasoning_text.get("signature"):
result["reasoningText"]["signature"] = reasoning_text["signature"]

if "redactedContent" in reasoning:
result["redactedContent"] = reasoning["redactedContent"]

return {"reasoningContent": result}

# Pass through text and other simple content types
if "text" in content:
return {"text": content["text"]}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
if "toolResult" in content:
tool_result = content["toolResult"]
formatted_content: list[dict[str, Any]] = []
for tool_result_content in tool_result["content"]:
if "json" in tool_result_content:
# Handle json field since not in ContentBlock but valid in ToolResultContent
formatted_content.append({"json": tool_result_content["json"]})
else:
formatted_content.append(
self._format_request_message_content(cast(ContentBlock, tool_result_content))
)

result = {
"content": formatted_content,
"toolUseId": tool_result["toolUseId"],
}
if "status" in tool_result and self._should_include_tool_result_status():
result["status"] = tool_result["status"]
return {"toolResult": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolUseBlock.html
if "toolUse" in content:
tool_use = content["toolUse"]
return {
"toolUse": {
"input": tool_use["input"],
"name": tool_use["name"],
"toolUseId": tool_use["toolUseId"],
}
}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html
if "video" in content:
video = content["video"]
source = video["source"]
formatted_source = {}
if "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": video["format"], "source": formatted_source}
return {"video": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
if "citationsContent" in content:
citations = content["citationsContent"]
result = {}

if "citations" in citations:
result["citations"] = []
for citation in citations["citations"]:
filtered_citation: dict[str, Any] = {}
if "location" in citation:
location = citation["location"]
filtered_location = {}
# Filter location fields to only include Bedrock-supported ones
if "documentIndex" in location:
filtered_location["documentIndex"] = location["documentIndex"]
if "start" in location:
filtered_location["start"] = location["start"]
if "end" in location:
filtered_location["end"] = location["end"]
filtered_citation["location"] = filtered_location
if "sourceContent" in citation:
filtered_source_content: list[dict[str, Any]] = []
for source_content in citation["sourceContent"]:
if "text" in source_content:
filtered_source_content.append({"text": source_content["text"]})
if filtered_source_content:
filtered_citation["sourceContent"] = filtered_source_content
if "title" in citation:
filtered_citation["title"] = citation["title"]
result["citations"].append(filtered_citation)

if "content" in citations:
filtered_content: list[dict[str, Any]] = []
for generated_content in citations["content"]:
if "text" in generated_content:
filtered_content.append({"text": generated_content["text"]})
if filtered_content:
result["content"] = filtered_content

return {"citationsContent": result}

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.

Expand Down
Loading
Loading