diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 61156849e7..9aa96ce217 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -409,13 +409,27 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
     if response_usage is None:
         return usage.Usage()

-    request_tokens = getattr(response_usage, 'input_tokens', None)
+    # Store all integer-typed usage values in the details dict
+    response_usage_dict = response_usage.model_dump()
+    details: dict[str, int] = {}
+    for key, value in response_usage_dict.items():
+        if isinstance(value, int):
+            details[key] = value
+
+    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence the getattr calls.
+    # Each token is counted only once across input_tokens, cache_creation_input_tokens, and cache_read_input_tokens,
+    # so request_tokens is the total of all input tokens, with the cached counts broken out in details.
+    request_tokens = (
+        getattr(response_usage, 'input_tokens', 0)
+        + (getattr(response_usage, 'cache_creation_input_tokens', 0) or 0)  # These can be missing, None, or int
+        + (getattr(response_usage, 'cache_read_input_tokens', 0) or 0)
+    )

     return usage.Usage(
-        # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
-        request_tokens=request_tokens,
+        request_tokens=request_tokens or None,
         response_tokens=response_usage.output_tokens,
-        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
+        total_tokens=request_tokens + response_usage.output_tokens,
+        details=details or None,
     )
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index 2e7d660711..b5d970f2a7 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -141,14 +141,29 @@ async def test_sync_request_text_response(allow_model_requests: None):

     result = await agent.run('hello')
     assert result.output == 'world'
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
-
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=5,
+            response_tokens=10,
+            total_tokens=15,
+            details={'input_tokens': 5, 'output_tokens': 10},
+        )
+    )
     # reset the index so we get the same response again
     mock_client.index = 0  # type: ignore

     result = await agent.run('hello', message_history=result.new_messages())
     assert result.output == 'world'
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=5,
+            response_tokens=10,
+            total_tokens=15,
+            details={'input_tokens': 5, 'output_tokens': 10},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -167,6 +182,38 @@ async def test_sync_request_text_response(allow_model_requests: None):
     )


+async def test_async_request_prompt_caching(allow_model_requests: None):
+    c = completion_message(
+        [TextBlock(text='world', type='text')],
+        usage=AnthropicUsage(
+            input_tokens=3,
+            output_tokens=5,
+            cache_creation_input_tokens=4,
+            cache_read_input_tokens=6,
+        ),
+    )
+    mock_client = MockAnthropic.create_mock(c)
+    m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    result = await agent.run('hello')
+    assert result.output == 'world'
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=13,
+            response_tokens=5,
+            total_tokens=18,
+            details={
+                'input_tokens': 3,
+                'output_tokens': 5,
+                'cache_creation_input_tokens': 4,
+                'cache_read_input_tokens': 6,
+            },
+        )
+    )
+
+
 async def test_async_request_text_response(allow_model_requests: None):
     c = completion_message(
         [TextBlock(text='world', type='text')],
@@ -178,7 +225,15 @@ async def test_async_request_text_response(allow_model_requests: None):

     result = await agent.run('hello')
     assert result.output == 'world'
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=3, response_tokens=5, total_tokens=8))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=3,
+            response_tokens=5,
+            total_tokens=8,
+            details={'input_tokens': 3, 'output_tokens': 5},
+        )
+    )


 async def test_request_structured_response(allow_model_requests: None):
@@ -551,7 +606,15 @@ async def my_tool(first: str, second: str) -> int:
         ]
     )
     assert result.is_complete
-    assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=2,
+            request_tokens=20,
+            response_tokens=5,
+            total_tokens=25,
+            details={'input_tokens': 20, 'output_tokens': 5},
+        )
+    )
     assert tool_called
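
The token accounting introduced in _map_usage can be sanity-checked in isolation. The sketch below is a minimal illustration: StubUsage is a hypothetical stand-in for the Anthropic SDK's usage payload (the real code reads these attributes off anthropic.types.Usage or the stream-delta events), and it reproduces the 3 + 4 + 6 = 13 request tokens and 13 + 5 = 18 total tokens asserted in test_async_request_prompt_caching.

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass
    class StubUsage:
        # Hypothetical stand-in for the SDK usage payload; cache fields may be None.
        input_tokens: int = 0
        output_tokens: int = 0
        cache_creation_input_tokens: int | None = None
        cache_read_input_tokens: int | None = None


    def request_tokens(u: StubUsage) -> int:
        # Mirrors the accounting above: each input token lands in exactly one of
        # the three buckets, so summing them never double-counts.
        return (
            getattr(u, 'input_tokens', 0)
            + (getattr(u, 'cache_creation_input_tokens', 0) or 0)
            + (getattr(u, 'cache_read_input_tokens', 0) or 0)
        )


    u = StubUsage(input_tokens=3, output_tokens=5, cache_creation_input_tokens=4, cache_read_input_tokens=6)
    assert request_tokens(u) == 13                    # 3 + 4 + 6, as asserted in the test
    assert request_tokens(u) + u.output_tokens == 18  # total_tokens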