diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 14c1d41256..0bb5e93b54 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -441,12 +441,17 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
         return usage.Usage()
 
     request_tokens = getattr(response_usage, 'input_tokens', None)
+    cache_creation_input_tokens = getattr(response_usage, 'cache_creation_input_tokens', None)
+    cache_read_input_tokens = getattr(response_usage, 'cache_read_input_tokens', None)
+
+    total_request_tokens = (request_tokens or 0) + (cache_creation_input_tokens or 0) + (cache_read_input_tokens or 0)
 
     return usage.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
-        request_tokens=request_tokens,
+        request_tokens=total_request_tokens,
         response_tokens=response_usage.output_tokens,
-        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
+        cached_tokens=cache_read_input_tokens,
+        total_tokens=total_request_tokens + response_usage.output_tokens,
     )
 
 
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 0013fc4349..d7dae49f4c 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -769,6 +769,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
+        cached_tokens=metadata.get('cached_content_token_count', 0),
         details=details,
     )
 
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 0e6ee6daaa..03fe4e05ce 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -508,9 +508,16 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usa
             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
         if response_usage.prompt_tokens_details is not None:
             details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
+        cached_tokens = 0
+        if (
+            response_usage.prompt_tokens_details is not None
+            and response_usage.prompt_tokens_details.cached_tokens is not None
+        ):
+            cached_tokens = response_usage.prompt_tokens_details.cached_tokens
         return usage.Usage(
             request_tokens=response_usage.prompt_tokens,
             response_tokens=response_usage.completion_tokens,
+            cached_tokens=cached_tokens,
             total_tokens=response_usage.total_tokens,
             details=details,
         )
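Note: the three provider mappings above differ in arithmetic because the providers report caching differently. Anthropic's `input_tokens` excludes cache writes and reads, hence the summing into `total_request_tokens`; OpenAI's `prompt_tokens` already includes cached tokens, so nothing extra is added there. A minimal sketch of the Anthropic arithmetic, using hypothetical token counts (not values from this diff):

# Hypothetical Anthropic usage numbers; `input_tokens` covers only the
# uncached part of the prompt, so the three buckets are summed.
input_tokens = 100                  # uncached prompt tokens
cache_creation_input_tokens = 2000  # prompt tokens written to the cache
cache_read_input_tokens = 1800      # prompt tokens served from the cache
output_tokens = 50

total_request_tokens = (
    (input_tokens or 0)
    + (cache_creation_input_tokens or 0)
    + (cache_read_input_tokens or 0)
)
assert total_request_tokens == 3900                  # becomes Usage.request_tokens
assert total_request_tokens + output_tokens == 3950  # becomes Usage.total_tokens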
""" self.requests += requests - for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': + for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens', 'cached_tokens': self_value = getattr(self, f) other_value = getattr(incr_usage, f) if self_value is not None or other_value is not None: @@ -61,6 +63,7 @@ def opentelemetry_attributes(self) -> dict[str, int]: result = { 'gen_ai.usage.input_tokens': self.request_tokens, 'gen_ai.usage.output_tokens': self.response_tokens, + 'gen_ai.usage.cached_tokens': self.cached_tokens, } for key, value in (self.details or {}).items(): result[f'gen_ai.usage.details.{key}'] = value diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 3631a164ab..82810d851b 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -393,7 +393,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: assert isinstance(n, BaseNode) n = await run.next() - assert n == snapshot(End(None)) + assert n == snapshot(End(data=None)) with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'): await run.next() diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 4f1a7579c4..182fb20f15 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -287,7 +287,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]: node = Foo() async with graph.iter(node, persistence=sp) as run: end = await run.next() - assert end == snapshot(End(123)) + assert end == snapshot(End(data=123)) msg = "Incorrect snapshot status 'success', must be 'created' or 'pending'." with pytest.raises(GraphNodeStatusError, match=msg): diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index b6265a962a..b1c0230984 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -467,7 +467,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ), ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot( + Usage(requests=1, request_tokens=1, response_tokens=2, cached_tokens=0, total_tokens=3) + ) result = await agent.run('Hello', message_history=result.new_messages()) assert result.data == 'Hello world' @@ -613,7 +615,9 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot( + Usage(requests=3, request_tokens=3, response_tokens=6, cached_tokens=0, total_tokens=9) + ) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -654,12 +658,16 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): 'Hello world', ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot( + Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6) + ) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot( + Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6) + ) async def 
diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py
index 3631a164ab..82810d851b 100644
--- a/tests/graph/test_graph.py
+++ b/tests/graph/test_graph.py
@@ -393,7 +393,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]:
         assert isinstance(n, BaseNode)
 
     n = await run.next()
-    assert n == snapshot(End(None))
+    assert n == snapshot(End(data=None))
 
     with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'):
         await run.next()
diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py
index 4f1a7579c4..182fb20f15 100644
--- a/tests/graph/test_persistence.py
+++ b/tests/graph/test_persistence.py
@@ -287,7 +287,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]:
     node = Foo()
     async with graph.iter(node, persistence=sp) as run:
         end = await run.next()
-    assert end == snapshot(End(123))
+    assert end == snapshot(End(data=123))
 
     msg = "Incorrect snapshot status 'success', must be 'created' or 'pending'."
     with pytest.raises(GraphNodeStatusError, match=msg):
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index b6265a962a..b1c0230984 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -467,7 +467,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
             ),
         ]
     )
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
+    assert result.usage() == snapshot(
+        Usage(requests=1, request_tokens=1, response_tokens=2, cached_tokens=0, total_tokens=3)
+    )
 
     result = await agent.run('Hello', message_history=result.new_messages())
     assert result.data == 'Hello world'
@@ -613,7 +615,9 @@ async def get_location(loc_name: str) -> str:
             ),
         ]
     )
-    assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9))
+    assert result.usage() == snapshot(
+        Usage(requests=3, request_tokens=3, response_tokens=6, cached_tokens=0, total_tokens=9)
+    )
 
 
 async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):
@@ -654,12 +658,16 @@ async def test_stream_text(get_gemini_client: GetGeminiClient):
                 'Hello world',
             ]
         )
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6)
+        )
 
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)]
         assert chunks == snapshot(['Hello ', 'world'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6)
+        )
 
 
 async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
@@ -691,7 +699,9 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream(debounce_by=None)]
         assert chunks == snapshot(['abc', 'abc€def', 'abc€def'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6)
+        )
 
 
 async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
@@ -721,7 +731,9 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient):
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream(debounce_by=None)]
         assert chunks == snapshot([(1, 2), (1, 2)])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=1, response_tokens=2, cached_tokens=0, total_tokens=3)
+        )
 
 
 async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
@@ -762,7 +774,9 @@ async def bar(y: str) -> str:
     async with agent.run_stream('Hello') as result:
         response = await result.get_data()
         assert response == snapshot((1, 2))
-    assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
+    assert result.usage() == snapshot(
+        Usage(requests=2, request_tokens=3, response_tokens=6, cached_tokens=0, total_tokens=9)
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
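The Gemini snapshots above all gain `cached_tokens=0` rather than `None` because `_metadata_as_usage` supplies a default whenever the metadata key is absent; a sketch with a hypothetical `usage_metadata` payload:

# Hypothetical Gemini usage_metadata; `cached_content_token_count` is only
# present when context caching is in use, so `.get(..., 0)` yields 0 here.
metadata = {'prompt_token_count': 1, 'candidates_token_count': 2, 'total_token_count': 3}

cached_tokens = metadata.get('cached_content_token_count', 0)
assert cached_tokens == 0  # matches the snapshots above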
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index ba56bc4d52..ea885a0f2e 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -212,7 +212,9 @@ async def test_request_simple_usage(allow_model_requests: None):
 
     result = await agent.run('Hello')
     assert result.data == 'world'
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3))
+    assert result.usage() == snapshot(
+        Usage(requests=1, request_tokens=2, response_tokens=1, cached_tokens=0, total_tokens=3)
+    )
 
 
 async def test_request_structured_response(allow_model_requests: None):
@@ -380,6 +382,7 @@ async def get_location(loc_name: str) -> str:
             requests=3,
             request_tokens=5,
             response_tokens=3,
+            cached_tokens=3,
             total_tokens=9,
             details={'cached_tokens': 3},
         )
@@ -416,7 +419,9 @@ async def test_stream_text(allow_model_requests: None):
         assert not result.is_complete
         assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
         assert result.is_complete
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=6, response_tokens=3, cached_tokens=0, total_tokens=9)
+        )
 
 
 async def test_stream_text_finish_reason(allow_model_requests: None):
@@ -487,7 +492,9 @@ async def test_stream_structured(allow_model_requests: None):
             ]
         )
         assert result.is_complete
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=20, response_tokens=10, cached_tokens=0, total_tokens=30)
+        )
         # double check usage matches stream count
         assert result.usage().response_tokens == len(stream)
 
@@ -543,7 +550,9 @@ async def test_no_delta(allow_model_requests: None):
         assert not result.is_complete
         assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
         assert result.is_complete
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9))
+        assert result.usage() == snapshot(
+            Usage(requests=1, request_tokens=6, response_tokens=3, cached_tokens=0, total_tokens=9)
+        )
 
 
 @pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None])
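Taken together, the new field is readable from any run result through the existing public API; a hedged end-to-end sketch (hypothetical model choice, requires provider credentials, and actual token counts depend on the provider's caching behaviour):

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')
result = agent.run_sync('Hello')
print(result.usage().cached_tokens)  # input tokens served from the prompt cache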