|
41 | 41 | SseServerParams,
|
42 | 42 | )
|
43 | 43 | from pydantic import BaseModel, ValidationError
|
44 |
| -from utils import FileLogHandler |
| 44 | +from utils import FileLogHandler, compare_messages, compare_task_results |
45 | 45 |
|
46 | 46 | logger = logging.getLogger(EVENT_LOGGER_NAME)
|
47 | 47 | logger.setLevel(logging.DEBUG)
|
@@ -180,9 +180,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
180 | 180 | index = 0
|
181 | 181 | async for message in agent.run_stream(task="task"):
|
182 | 182 | if isinstance(message, TaskResult):
|
183 |
| - assert message == result |
| 183 | + assert compare_task_results(message, result) |
184 | 184 | else:
|
185 |
| - assert message == result.messages[index] |
| 185 | + assert compare_messages(message, result.messages[index]) |
186 | 186 | index += 1
|
187 | 187 |
|
188 | 188 | # Test state saving and loading.
|
@@ -273,9 +273,9 @@ async def test_run_with_tools_and_reflection() -> None:
|
273 | 273 | index = 0
|
274 | 274 | async for message in agent.run_stream(task="task"):
|
275 | 275 | if isinstance(message, TaskResult):
|
276 |
| - assert message == result |
| 276 | + assert compare_task_results(message, result) |
277 | 277 | else:
|
278 |
| - assert message == result.messages[index] |
| 278 | + assert compare_messages(message, result.messages[index]) |
279 | 279 | index += 1
|
280 | 280 |
|
281 | 281 | # Test state saving and loading.
|
@@ -363,9 +363,9 @@ async def test_run_with_parallel_tools() -> None:
|
363 | 363 | index = 0
|
364 | 364 | async for message in agent.run_stream(task="task"):
|
365 | 365 | if isinstance(message, TaskResult):
|
366 |
| - assert message == result |
| 366 | + assert compare_task_results(message, result) |
367 | 367 | else:
|
368 |
| - assert message == result.messages[index] |
| 368 | + assert compare_messages(message, result.messages[index]) |
369 | 369 | index += 1
|
370 | 370 |
|
371 | 371 | # Test state saving and loading.
|
@@ -446,9 +446,9 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
|
446 | 446 | index = 0
|
447 | 447 | async for message in agent.run_stream(task="task"):
|
448 | 448 | if isinstance(message, TaskResult):
|
449 |
| - assert message == result |
| 449 | + assert compare_task_results(message, result) |
450 | 450 | else:
|
451 |
| - assert message == result.messages[index] |
| 451 | + assert compare_messages(message, result.messages[index]) |
452 | 452 | index += 1
|
453 | 453 |
|
454 | 454 | # Test state saving and loading.
|
@@ -560,9 +560,9 @@ async def test_run_with_workbench() -> None:
|
560 | 560 | index = 0
|
561 | 561 | async for message in agent.run_stream(task="task"):
|
562 | 562 | if isinstance(message, TaskResult):
|
563 |
| - assert message == result |
| 563 | + assert compare_task_results(message, result) |
564 | 564 | else:
|
565 |
| - assert message == result.messages[index] |
| 565 | + assert compare_messages(message, result.messages[index]) |
566 | 566 | index += 1
|
567 | 567 |
|
568 | 568 | # Test state saving and loading.
|
@@ -779,9 +779,9 @@ async def test_handoffs() -> None:
|
779 | 779 | index = 0
|
780 | 780 | async for message in tool_use_agent.run_stream(task="task"):
|
781 | 781 | if isinstance(message, TaskResult):
|
782 |
| - assert message == result |
| 782 | + assert compare_task_results(message, result) |
783 | 783 | else:
|
784 |
| - assert message == result.messages[index] |
| 784 | + assert compare_messages(message, result.messages[index]) |
785 | 785 | index += 1
|
786 | 786 |
|
787 | 787 |
|
@@ -852,9 +852,9 @@ async def test_handoff_with_tool_call_context() -> None:
|
852 | 852 | index = 0
|
853 | 853 | async for message in tool_use_agent.run_stream(task="task"):
|
854 | 854 | if isinstance(message, TaskResult):
|
855 |
| - assert message == result |
| 855 | + assert compare_task_results(message, result) |
856 | 856 | else:
|
857 |
| - assert message == result.messages[index] |
| 857 | + assert compare_messages(message, result.messages[index]) |
858 | 858 | index += 1
|
859 | 859 |
|
860 | 860 |
|
@@ -927,9 +927,9 @@ def _next_action(action: str) -> str:
|
927 | 927 | index = 0
|
928 | 928 | async for message in tool_use_agent.run_stream(task="task"):
|
929 | 929 | if isinstance(message, TaskResult):
|
930 |
| - assert message == result |
| 930 | + assert compare_task_results(message, result) |
931 | 931 | else:
|
932 |
| - assert message == result.messages[index] |
| 932 | + assert compare_messages(message, result.messages[index]) |
933 | 933 | index += 1
|
934 | 934 |
|
935 | 935 |
|
@@ -1004,9 +1004,9 @@ def _next_action(action: str) -> Dict[str, str]:
|
1004 | 1004 | index = 0
|
1005 | 1005 | async for message in tool_use_agent.run_stream(task="task"):
|
1006 | 1006 | if isinstance(message, TaskResult):
|
1007 |
| - assert message == result |
| 1007 | + assert compare_task_results(message, result) |
1008 | 1008 | else:
|
1009 |
| - assert message == result.messages[index] |
| 1009 | + assert compare_messages(message, result.messages[index]) |
1010 | 1010 | index += 1
|
1011 | 1011 |
|
1012 | 1012 |
|
@@ -1161,9 +1161,9 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
1161 | 1161 | index = 0
|
1162 | 1162 | async for message in agent.run_stream(task=messages):
|
1163 | 1163 | if isinstance(message, TaskResult):
|
1164 |
| - assert message == result |
| 1164 | + assert compare_task_results(message, result) |
1165 | 1165 | else:
|
1166 |
| - assert message == result.messages[index] |
| 1166 | + assert compare_messages(message, result.messages[index]) |
1167 | 1167 | index += 1
|
1168 | 1168 |
|
1169 | 1169 |
|
|
0 commit comments