From 32fe567e25cd72ad9c7d638d4107e789392cfa66 Mon Sep 17 00:00:00 2001 From: Ratish1 Date: Wed, 19 Nov 2025 22:47:04 +0400 Subject: [PATCH 1/2] fix(gemini): support thought signatures for tool calls --- src/strands/models/gemini.py | 33 +++++++++++++++++++--------- tests/strands/models/test_gemini.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c24d91a0d..d146e29d1 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -141,14 +141,18 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: - return genai.types.Part( + thought_signature = cast(Optional[str], content["toolUse"].get("thoughtSignature")) + + part = genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], id=content["toolUse"]["toolUseId"], name=content["toolUse"]["name"], ), ) - + if thought_signature: + part.thought_signature = thought_signature.encode("utf-8") + return part raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") def _format_request_content(self, messages: Messages) -> list[genai.types.Content]: @@ -268,14 +272,15 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: # that name be set in the equivalent FunctionResponse type. Consequently, we assign # function name to toolUseId in our tool use block. And another reason, function_call is # not guaranteed to have id populated. + tool_use: dict[str, Any] = { + "name": event["data"].function_call.name, + "toolUseId": event["data"].function_call.name, + } + if event.get("thought_signature"): + tool_use["thoughtSignature"] = event["thought_signature"] return { "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function_call.name, - "toolUseId": event["data"].function_call.name, - }, - }, + "start": {"toolUse": cast(Any, tool_use)}, }, } @@ -302,7 +307,7 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: if event["data"].thought_signature else {} ), - }, + } }, }, } @@ -378,10 +383,18 @@ async def stream( candidate = candidates[0] if candidates else None content = candidate.content if candidate else None parts = content.parts if content and content.parts else [] + thought_signature = getattr(candidate, "thought_signature", None) if candidate else None for part in parts: if part.function_call: - yield self._format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": part}) + yield self._format_chunk( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": part, + "thought_signature": thought_signature, + } + ) yield self._format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": part}) yield self._format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": part}) tool_used = True diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index a8f5351cc..3368f6dc3 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -1,3 +1,4 @@ +import base64 import json import logging import unittest.mock @@ -637,3 +638,36 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap assert "Gemini API returned non-JSON error" in caplog.text assert f"error_message=<{error_message}>" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_request_preserves_tool_use_signature(gemini_client, model): + """Verify that a signature stored in a previous toolUse turn is sent back to the API.""" + signature_str = "original_signature_string" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + "thoughtSignature": signature_str, + }, + }, + ], + } + ] + + await anext(model.stream(messages)) + + call_args = gemini_client.aio.models.generate_content_stream.call_args + _, kwargs = call_args + sent_contents = kwargs["contents"] + + tool_part = sent_contents[0]["parts"][0] + + expected_signature = base64.b64encode(signature_str.encode("utf-8")).decode("utf-8") + assert "function_call" in tool_part + assert tool_part["thought_signature"] == expected_signature From de86818de37685cba3df3b986f8c7a93fcd0b4e5 Mon Sep 17 00:00:00 2001 From: Ratish1 Date: Thu, 20 Nov 2025 20:07:24 +0400 Subject: [PATCH 2/2] fix(gemini): fix tests and thought signature --- src/strands/models/gemini.py | 2 +- tests/strands/models/test_gemini.py | 73 ++++++++++++++++++++++++----- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index d146e29d1..31b31f2a8 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -277,7 +277,7 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: "toolUseId": event["data"].function_call.name, } if event.get("thought_signature"): - tool_use["thoughtSignature"] = event["thought_signature"] + tool_use["thoughtSignature"] = event["thought_signature"].decode("utf-8") return { "contentBlockStart": { "start": {"toolUse": cast(Any, tool_use)}, diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 3368f6dc3..32a93aa33 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -641,9 +641,9 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap @pytest.mark.asyncio -async def test_stream_request_preserves_tool_use_signature(gemini_client, model): - """Verify that a signature stored in a previous toolUse turn is sent back to the API.""" - signature_str = "original_signature_string" +async def test_stream_request_with_tool_use_and_thought_signature(gemini_client, model): + """Verify that a thought_signature in a toolUse is sent back to the API as a base64 string.""" + signature_input = "test_signature_to_send_back" messages = [ { "role": "assistant", @@ -653,21 +653,68 @@ async def test_stream_request_preserves_tool_use_signature(gemini_client, model) "toolUseId": "c1", "name": "calculator", "input": {"expression": "2+2"}, - "thoughtSignature": signature_str, - }, - }, + "thoughtSignature": signature_input, + } + } ], } ] await anext(model.stream(messages)) - call_args = gemini_client.aio.models.generate_content_stream.call_args - _, kwargs = call_args - sent_contents = kwargs["contents"] + called_with_kwargs = gemini_client.aio.models.generate_content_stream.call_args.kwargs + sent_contents_as_dict = called_with_kwargs["contents"] + + tool_use_part_as_dict = sent_contents_as_dict[0]["parts"][0] + + assert "function_call" in tool_use_part_as_dict + + expected_b64 = base64.b64encode(signature_input.encode("utf-8")).decode("utf-8") + assert tool_use_part_as_dict.get("thought_signature") == expected_b64 + + +@pytest.mark.asyncio +async def test_stream_response_tool_use_with_thought_signature(gemini_client, model, messages, agenerator, alist): + """Test that thought signature from response is properly captured and stored in toolUse.""" + mock_candidate = unittest.mock.Mock() + mock_candidate.finish_reason = "TOOL_USE" + + mock_fn = unittest.mock.Mock(args={"expression": "2+2"}) + mock_fn.name = "calculator" + + mock_part = unittest.mock.Mock() + mock_part.function_call = mock_fn + mock_part.text = None + mock_part.thought = False + + mock_candidate.content.parts = [mock_part] + mock_candidate.thought_signature = b"sig123" + + mock_response = unittest.mock.Mock() + mock_response.candidates = [mock_candidate] + + mock_meta = unittest.mock.Mock() + mock_meta.prompt_token_count = 1 + mock_meta.total_token_count = 2 + mock_response.usage_metadata = mock_meta - tool_part = sent_contents[0]["parts"][0] + gemini_client.aio.models.generate_content_stream.return_value = agenerator([mock_response]) - expected_signature = base64.b64encode(signature_str.encode("utf-8")).decode("utf-8") - assert "function_call" in tool_part - assert tool_part["thought_signature"] == expected_signature + tru_chunks = await alist(model.stream(messages)) + + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + { + "contentBlockStart": { + "start": {"toolUse": {"name": "calculator", "toolUseId": "calculator", "thoughtSignature": "sig123"}} + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 2}, "metrics": {"latencyMs": 0}}}, + ] + + assert tru_chunks == exp_chunks