Skip to content

Commit 7be26da

Browse files
authored
Add a test for output type coercion (#1076)
* Add a test for output type coercion * remove prints
1 parent e92f391 commit 7be26da

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tests/contrib/openai_agents/test_openai.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, AsyncIterator, Optional, Union, no_type_check
99

1010
import nexusrpc
11+
import pydantic
1112
import pytest
1213
from agents import (
1314
Agent,
@@ -2221,3 +2222,59 @@ def provide(
22212222
async for e in workflow_handle.fetch_history_events():
22222223
if e.HasField("activity_task_scheduled_event_attributes"):
22232224
assert e.user_metadata.summary.data == b'"My summary"'
2225+
2226+
2227+
class OutputType(pydantic.BaseModel):
2228+
answer: str
2229+
model_config = ConfigDict(extra="forbid") # Forbid additional properties
2230+
2231+
2232+
@workflow.defn
2233+
class OutputTypeWorkflow:
2234+
@workflow.run
2235+
async def run(self) -> OutputType:
2236+
agent: Agent = Agent(
2237+
name="Assistant",
2238+
instructions="You are a helpful assistant, adhere to the json schema output",
2239+
output_type=OutputType,
2240+
)
2241+
result = await Runner.run(
2242+
starting_agent=agent,
2243+
input="Hello!",
2244+
)
2245+
return result.final_output
2246+
2247+
2248+
class OutputTypeModel(StaticTestModel):
2249+
responses = [
2250+
ResponseBuilders.output_message(
2251+
'{"answer": "My answer"}',
2252+
),
2253+
]
2254+
2255+
2256+
async def test_output_type(client: Client):
2257+
new_config = client.config()
2258+
new_config["plugins"] = [
2259+
openai_agents.OpenAIAgentsPlugin(
2260+
model_params=ModelActivityParameters(
2261+
start_to_close_timeout=timedelta(seconds=120),
2262+
),
2263+
model_provider=TestModelProvider(OutputTypeModel()),
2264+
)
2265+
]
2266+
client = Client(**new_config)
2267+
2268+
async with new_worker(
2269+
client,
2270+
OutputTypeWorkflow,
2271+
) as worker:
2272+
workflow_handle = await client.start_workflow(
2273+
OutputTypeWorkflow.run,
2274+
id=f"output-type-{uuid.uuid4()}",
2275+
task_queue=worker.task_queue,
2276+
execution_timeout=timedelta(seconds=10),
2277+
)
2278+
result = await workflow_handle.result()
2279+
assert isinstance(result, OutputType)
2280+
assert result.answer == "My answer"

0 commit comments

Comments
 (0)