diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c74a63a3b..b9aae9b06 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -55,13 +55,13 @@ "anthropic.claude", ] -# Cache of model IDs that do not support the CountTokens API. -_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set() +# Cache of model IDs for which CountTokens API calls should be skipped. +_SKIP_COUNT_TOKENS_MODELS: set[str] = set() -def _clear_unsupported_count_tokens_cache() -> None: - """Clear the cache of model IDs that do not support the CountTokens API.""" - _UNSUPPORTED_COUNT_TOKENS_MODELS.clear() +def _clear_skip_count_tokens_cache() -> None: + """Clear the cache of model IDs for which CountTokens API calls should be skipped.""" + _SKIP_COUNT_TOKENS_MODELS.clear() T = TypeVar("T", bound=BaseModel) @@ -803,7 +803,7 @@ async def count_tokens( model_id: str = self.config["model_id"] - if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS: + if model_id in _SKIP_COUNT_TOKENS_MODELS: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: @@ -833,6 +833,17 @@ async def count_tokens( return total_tokens except Exception as e: if ( + isinstance(e, ClientError) + and e.response.get("Error", {}).get("Code") == "AccessDeniedException" + ): + logger.warning( + "model_id=<%s> | bedrock:CountTokens permission denied," + " falling back to heuristic estimation: %s", + model_id, + e, + ) + _SKIP_COUNT_TOKENS_MODELS.add(model_id) + elif ( isinstance(e, ClientError) and e.response.get("Error", {}).get("Code") == "ValidationException" and "doesn't support counting tokens" in str(e) @@ -842,7 +853,7 @@ async def count_tokens( " falling back to estimation", model_id, ) - _UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id) + _SKIP_COUNT_TOKENS_MODELS.add(model_id) else: logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2f1f7d1f1..b65d77234 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -19,7 +19,7 @@ DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, - _clear_unsupported_count_tokens_cache, + _clear_skip_count_tokens_cache, ) from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec @@ -3336,9 +3336,9 @@ class TestCountTokens: @pytest.fixture(autouse=True) def clean_cache(self): - _clear_unsupported_count_tokens_cache() + _clear_skip_count_tokens_cache() yield - _clear_unsupported_count_tokens_cache() + _clear_skip_count_tokens_cache() @pytest.fixture def model_with_client(self, bedrock_client, model_id): @@ -3473,6 +3473,54 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_clien await model.count_tokens(messages=messages) assert bedrock_client.count_tokens.call_count == 1 + @pytest.mark.asyncio + async def test_caches_model_id_when_access_denied(self, bedrock_client, messages): + model = BedrockModel(model_id="access-denied-cache-test-model") + bedrock_client.count_tokens.side_effect = ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" + " to perform: bedrock:CountTokens", + } + }, + "CountTokens", + ) + + # First call: hits API, gets error, caches + await model.count_tokens(messages=messages) + bedrock_client.count_tokens.assert_called_once() + + # Reset mock to clearly verify second call doesn't hit the API + bedrock_client.count_tokens.reset_mock() + + # Second call: skips API entirely due to caching + result = await model.count_tokens(messages=messages) + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_access_denied_logs_warning_with_full_error( + self, model_with_client, bedrock_client, messages, caplog + ): + error_message = ( + "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" + " to perform: bedrock:CountTokens" + ) + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": error_message}}, + "CountTokens", + ) + + with caplog.at_level(logging.WARNING, logger="strands.models.bedrock"): + await model_with_client.count_tokens(messages=messages) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warning_records) == 1 + assert "bedrock:CountTokens permission denied" in warning_records[0].message + assert error_message in warning_records[0].message + @pytest.mark.asyncio async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): model = BedrockModel(model_id="transient-error-test-model")