
Commit aafa80b

google-genai-bot authored and copybara-github committed
fix: stream in litellm + adk and add corresponding integration tests
Fixes #1368

PiperOrigin-RevId: 772218385
1 parent 4bda245 commit aafa80b

3 files changed: +125 additions, −12 deletions

src/google/adk/models/lite_llm.py

Lines changed: 2 additions & 1 deletion
@@ -739,11 +739,12 @@ async def generate_content_async(
             _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(
                     role="assistant",
-                    content="",
+                    content=text,
                     tool_calls=tool_calls,
                 )
             )
         )
+        text = ""
         function_calls.clear()
       elif finish_reason == "stop" and text:
         aggregated_llm_response = _message_to_generate_content_response(
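For context, this hunk sits in the streaming branch of generate_content_async: text deltas are buffered in `text` while tool-call deltas accumulate in `function_calls`. Before this change, a flushed tool-call response carried `content=""`, so any text streamed ahead of a tool call was dropped, and the buffer was never reset. A minimal, self-contained sketch of that pattern, assuming stand-in types; `Chunk` and `aggregate` are illustrative, not the real ADK internals:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Chunk:
  # Stand-in for ADK's internal streaming chunk types.
  text: str = ""
  tool_call: Optional[str] = None
  finish_reason: Optional[str] = None


def aggregate(chunks: List[Chunk]) -> List[dict]:
  """Buffers text deltas and flushes them alongside pending tool calls."""
  text, tool_calls, responses = "", [], []
  for chunk in chunks:
    text += chunk.text
    if chunk.tool_call:
      tool_calls.append(chunk.tool_call)
    if chunk.finish_reason == "tool_calls" and tool_calls:
      # The fix: emit the buffered text (previously content="") ...
      responses.append({"content": text, "tool_calls": list(tool_calls)})
      text = ""  # ... and reset the buffer so a later "stop" chunk
      tool_calls.clear()  # does not re-emit the same text.
    elif chunk.finish_reason == "stop" and text:
      responses.append({"content": text, "tool_calls": []})
  return responses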

tests/integration/models/test_litellm_no_function.py

Lines changed: 105 additions & 4 deletions
@@ -20,12 +20,26 @@
 from google.genai.types import Part
 import pytest
 
-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
 
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"
 
 _SYSTEM_PROMPT = """You are a helpful assistant."""
 
 
+def get_weather(city: str) -> str:
+  """Simulates a web search. Use it to get information on weather.
+
+  Args:
+    city: A string containing the location to get weather information for.
+
+  Returns:
+    A string with the simulated weather information for the queried city.
+  """
+  if "sf" in city.lower() or "san francisco" in city.lower():
+    return "It's 70 degrees and foggy."
+  return "It's 80 degrees and sunny."
+
+
 @pytest.fixture
 def oss_llm():
   return LiteLlm(model=_TEST_MODEL_NAME)
@@ -44,17 +58,57 @@ def llm_request():
   )
 
 
+@pytest.fixture
+def llm_request_with_tools():
+  return LlmRequest(
+      model=_TEST_MODEL_NAME,
+      contents=[
+          Content(
+              role="user",
+              parts=[
+                  Part.from_text(text="What is the weather in San Francisco?")
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction=_SYSTEM_PROMPT,
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="get_weather",
+                          description="Get the weather in a given location",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "city": types.Schema(
+                                      type=types.Type.STRING,
+                                      description=(
+                                          "The city to get the weather for."
+                                      ),
+                                  ),
+                              },
+                              required=["city"],
+                          ),
+                      )
+                  ]
+              )
+          ],
+      ),
+  )
+
+
 @pytest.mark.asyncio
 async def test_generate_content_async(oss_llm, llm_request):
   async for response in oss_llm.generate_content_async(llm_request):
     assert isinstance(response, LlmResponse)
     assert response.content.parts[0].text
 
 
-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
-async def test_generate_content_async_stream(oss_llm, llm_request):
+async def test_generate_content_async(oss_llm, llm_request):
   responses = [
       resp
       async for resp in oss_llm.generate_content_async(
@@ -63,3 +117,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request):
   ]
   part = responses[0].content.parts[0]
   assert len(part.text) > 0
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=False
+      )
+  ]
+  function_call = responses[0].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(oss_llm, llm_request):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(llm_request, stream=True)
+  ]
+  text = ""
+  for i in range(len(responses) - 1):
+    assert responses[i].partial is True
+    assert responses[i].content.parts[0].text
+    text += responses[i].content.parts[0].text
+
+  # The last message should be the accumulated text.
+  assert responses[-1].content.parts[0].text == text
+  assert not responses[-1].partial
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
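The new streaming test doubles as documentation of the contract: every response but the last arrives with `partial is True` and carries a text delta, and the final response repeats the fully accumulated text. A minimal consumer sketch built only on that contract; the `stream_answer` helper and print-based rendering are illustrative, not part of ADK:

import asyncio


async def stream_answer(llm, llm_request) -> str:
  """Prints deltas as they arrive and returns the final aggregated text."""
  final_text = ""
  async for response in llm.generate_content_async(llm_request, stream=True):
    part = response.content.parts[0]
    if response.partial:
      print(part.text, end="", flush=True)  # incremental delta
    elif part.text:
      final_text = part.text  # final response: the accumulated text
  return final_text

# Usage, assuming fixtures like the ones above:
#   llm = LiteLlm(model=_TEST_MODEL_NAME)
#   answer = asyncio.run(stream_answer(llm, llm_request))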

tests/integration/models/test_litellm_with_function.py

Lines changed: 18 additions & 7 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from google.adk.models import LlmRequest
-from google.adk.models import LlmResponse
 from google.adk.models.lite_llm import LiteLlm
 from google.genai import types
 from google.genai.types import Content
@@ -23,12 +22,11 @@
 
 litellm.add_function_to_prompt = True
 
-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
-
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"
 
 _SYSTEM_PROMPT = """
 You are a helpful assistant, and call tools optionally.
-If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs.
+If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs.
 """
 
 
@@ -40,7 +38,7 @@
         "properties": {
             "city": {
                 "type": "string",
-                "description": "The city, e.g. San Francisco",
+                "description": "The city to get the weather for.",
            },
        },
        "required": ["city"],
@@ -87,8 +85,6 @@ def llm_request():
   )
 
 
-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
 async def test_generate_content_asyn_with_function(
     oss_llm_with_function, llm_request
@@ -102,3 +98,18 @@ async def test_generate_content_asyn_with_function(
   function_call = responses[0].content.parts[0].function_call
   assert function_call.name == "get_weather"
   assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_asyn_stream_with_function(
+    oss_llm_with_function, llm_request
+):
+  responses = [
+      resp
+      async for resp in oss_llm_with_function.generate_content_async(
+          llm_request, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
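Note the context line `litellm.add_function_to_prompt = True` above: it is litellm's switch for injecting function schemas into the prompt for models that lack a native tool-calling API, and the dict edited in the third hunk is part of an OpenAI-style function schema. A hedged reconstruction of what the full declaration plausibly looks like; only the "properties", "city", "type", "description", and "required" keys actually appear in the hunk, so the surrounding keys and the variable name are assumptions:

# Hypothetical reconstruction for illustration; the "name", "description",
# and "parameters" keys are assumed, not shown in the diff.
_FUNCTION_DECLARATION = {
    "name": "get_weather",
    "description": "Get the weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "The city to get the weather for.",
            },
        },
        "required": ["city"],
    },
}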
