|
8 | 8 | from typing import Any, AsyncIterator, Optional, Union, no_type_check
|
9 | 9 |
|
10 | 10 | import nexusrpc
|
| 11 | +import pydantic |
11 | 12 | import pytest
|
12 | 13 | from agents import (
|
13 | 14 | Agent,
|
@@ -2221,3 +2222,59 @@ def provide(
|
2221 | 2222 | async for e in workflow_handle.fetch_history_events():
|
2222 | 2223 | if e.HasField("activity_task_scheduled_event_attributes"):
|
2223 | 2224 | assert e.user_metadata.summary.data == b'"My summary"'
|
| 2225 | + |
| 2226 | + |
| 2227 | +class OutputType(pydantic.BaseModel): |
| 2228 | + answer: str |
| 2229 | + model_config = ConfigDict(extra="forbid") # Forbid additional properties |
| 2230 | + |
| 2231 | + |
| 2232 | +@workflow.defn |
| 2233 | +class OutputTypeWorkflow: |
| 2234 | + @workflow.run |
| 2235 | + async def run(self) -> OutputType: |
| 2236 | + agent: Agent = Agent( |
| 2237 | + name="Assistant", |
| 2238 | + instructions="You are a helpful assistant, adhere to the json schema output", |
| 2239 | + output_type=OutputType, |
| 2240 | + ) |
| 2241 | + result = await Runner.run( |
| 2242 | + starting_agent=agent, |
| 2243 | + input="Hello!", |
| 2244 | + ) |
| 2245 | + return result.final_output |
| 2246 | + |
| 2247 | + |
| 2248 | +class OutputTypeModel(StaticTestModel): |
| 2249 | + responses = [ |
| 2250 | + ResponseBuilders.output_message( |
| 2251 | + '{"answer": "My answer"}', |
| 2252 | + ), |
| 2253 | + ] |
| 2254 | + |
| 2255 | + |
| 2256 | +async def test_output_type(client: Client): |
| 2257 | + new_config = client.config() |
| 2258 | + new_config["plugins"] = [ |
| 2259 | + openai_agents.OpenAIAgentsPlugin( |
| 2260 | + model_params=ModelActivityParameters( |
| 2261 | + start_to_close_timeout=timedelta(seconds=120), |
| 2262 | + ), |
| 2263 | + model_provider=TestModelProvider(OutputTypeModel()), |
| 2264 | + ) |
| 2265 | + ] |
| 2266 | + client = Client(**new_config) |
| 2267 | + |
| 2268 | + async with new_worker( |
| 2269 | + client, |
| 2270 | + OutputTypeWorkflow, |
| 2271 | + ) as worker: |
| 2272 | + workflow_handle = await client.start_workflow( |
| 2273 | + OutputTypeWorkflow.run, |
| 2274 | + id=f"output-type-{uuid.uuid4()}", |
| 2275 | + task_queue=worker.task_queue, |
| 2276 | + execution_timeout=timedelta(seconds=10), |
| 2277 | + ) |
| 2278 | + result = await workflow_handle.result() |
| 2279 | + assert isinstance(result, OutputType) |
| 2280 | + assert result.answer == "My answer" |
0 commit comments