Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@
"anthropic.claude",
]

# Cache of model IDs that do not support the CountTokens API.
_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set()
# Cache of model IDs for which CountTokens API calls should be skipped.
_SKIP_COUNT_TOKENS_MODELS: set[str] = set()


def _clear_unsupported_count_tokens_cache() -> None:
    """Empty the module-level cache of model IDs lacking CountTokens support.

    The set is emptied in place rather than rebound, so any code that already
    holds a reference to it sees the cleared contents.
    """
    _UNSUPPORTED_COUNT_TOKENS_MODELS.clear()
def _clear_skip_count_tokens_cache() -> None:
    """Empty the module-level cache of model IDs whose CountTokens calls are skipped.

    The set is emptied in place rather than rebound, so any code that already
    holds a reference to it sees the cleared contents.
    """
    _SKIP_COUNT_TOKENS_MODELS.clear()


T = TypeVar("T", bound=BaseModel)
Expand Down Expand Up @@ -803,7 +803,7 @@ async def count_tokens(

model_id: str = self.config["model_id"]

if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
if model_id in _SKIP_COUNT_TOKENS_MODELS:
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)

try:
Expand Down Expand Up @@ -833,6 +833,17 @@ async def count_tokens(
return total_tokens
except Exception as e:
if (
isinstance(e, ClientError)
and e.response.get("Error", {}).get("Code") == "AccessDeniedException"
):
logger.warning(
"model_id=<%s> | bedrock:CountTokens permission denied,"
" falling back to heuristic estimation: %s",
model_id,
e,
)
_SKIP_COUNT_TOKENS_MODELS.add(model_id)
elif (
isinstance(e, ClientError)
and e.response.get("Error", {}).get("Code") == "ValidationException"
and "doesn't support counting tokens" in str(e)
Expand All @@ -842,7 +853,7 @@ async def count_tokens(
" falling back to estimation",
model_id,
)
_UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id)
_SKIP_COUNT_TOKENS_MODELS.add(model_id)
else:
logger.debug(
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
Expand Down
54 changes: 51 additions & 3 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_REGION,
DEFAULT_READ_TIMEOUT,
_clear_unsupported_count_tokens_cache,
_clear_skip_count_tokens_cache,
)
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
from strands.types.tools import ToolSpec
Expand Down Expand Up @@ -3336,9 +3336,9 @@ class TestCountTokens:

@pytest.fixture(autouse=True)
def clean_cache(self):
# Autouse fixture: clears the CountTokens model-ID cache before the test body
# runs (the calls ahead of `yield`) and once more afterwards (the calls after it).
# NOTE(review): this is a flattened diff view. Each `_clear_unsupported_...` call
# appears to be the removed line and the `_clear_skip_...` call directly below it
# its replacement -- confirm against the actual test file.
_clear_unsupported_count_tokens_cache()
_clear_skip_count_tokens_cache()
yield
_clear_unsupported_count_tokens_cache()
_clear_skip_count_tokens_cache()

@pytest.fixture
def model_with_client(self, bedrock_client, model_id):
Expand Down Expand Up @@ -3473,6 +3473,54 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_clien
await model.count_tokens(messages=messages)
assert bedrock_client.count_tokens.call_count == 1

@pytest.mark.asyncio
async def test_caches_model_id_when_access_denied(self, bedrock_client, messages):
    """An AccessDeniedException from CountTokens causes later calls to skip the API.

    The first call reaches the mocked client and receives the error; the second
    call must not touch the client and must still return a non-negative int.
    """
    model = BedrockModel(model_id="access-denied-cache-test-model")
    denied_response = {
        "Error": {
            "Code": "AccessDeniedException",
            "Message": (
                "User: arn:aws:sts::123456789012:assumed-role/role is not authorized"
                " to perform: bedrock:CountTokens"
            ),
        }
    }
    count_tokens_mock = bedrock_client.count_tokens
    count_tokens_mock.side_effect = ClientError(denied_response, "CountTokens")

    # Initial request goes to the API, fails, and the model ID gets cached.
    await model.count_tokens(messages=messages)
    count_tokens_mock.assert_called_once()

    # Wipe the recorded calls so the follow-up assertion is unambiguous.
    count_tokens_mock.reset_mock()

    # Follow-up request must bypass the API because of the cached model ID.
    fallback_estimate = await model.count_tokens(messages=messages)
    count_tokens_mock.assert_not_called()
    assert isinstance(fallback_estimate, int)
    assert fallback_estimate >= 0

@pytest.mark.asyncio
async def test_access_denied_logs_warning_with_full_error(
    self, model_with_client, bedrock_client, messages, caplog
):
    """An AccessDeniedException yields exactly one WARNING that carries the error text.

    The warning must mention the denied ``bedrock:CountTokens`` permission and
    include the original error message from the service response.
    """
    denied_detail = (
        "User: arn:aws:sts::123456789012:assumed-role/role is not authorized"
        " to perform: bedrock:CountTokens"
    )
    error_payload = {"Error": {"Code": "AccessDeniedException", "Message": denied_detail}}
    bedrock_client.count_tokens.side_effect = ClientError(error_payload, "CountTokens")

    with caplog.at_level(logging.WARNING, logger="strands.models.bedrock"):
        await model_with_client.count_tokens(messages=messages)

    # Only WARNING-level records are of interest; other levels are ignored.
    warnings_seen = [record for record in caplog.records if record.levelno == logging.WARNING]
    assert len(warnings_seen) == 1
    logged_text = warnings_seen[0].message
    assert "bedrock:CountTokens permission denied" in logged_text
    assert denied_detail in logged_text

@pytest.mark.asyncio
async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
model = BedrockModel(model_id="transient-error-test-model")
Expand Down
Loading