Skip to content

Commit db125fb

Browse files
withsmilojackgerritsekzhu
authored
Add created_at to BaseChatMessage and BaseAgentEvent (#6557)
## Why are these changes needed? I added `created_at` to both BaseChatMessage and BaseAgentEvent classes that store the time these Pydantic model instances are generated. And then users will be able to use `created_at` to build up a customized external persisting state management layer for their case. ## Related issue number #6169 (reply in thread) ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
1 parent 726e0be commit db125fb

File tree

7 files changed

+113
-64
lines changed

7 files changed

+113
-64
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/messages.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class and includes specific fields relevant to the type of message being sent.
55
"""
66

77
from abc import ABC, abstractmethod
8+
from datetime import datetime, timezone
89
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
910

1011
from autogen_core import Component, ComponentBase, FunctionCall, Image
@@ -85,6 +86,9 @@ class BaseChatMessage(BaseMessage, ABC):
8586
metadata: Dict[str, str] = {}
8687
"""Additional metadata about the message."""
8788

89+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
90+
"""The time when the message was created."""
91+
8892
@abstractmethod
8993
def to_model_text(self) -> str:
9094
"""Convert the content of the message to text-only representation.
@@ -154,6 +158,9 @@ class BaseAgentEvent(BaseMessage, ABC):
154158
metadata: Dict[str, str] = {}
155159
"""Additional metadata about the message."""
156160

161+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
162+
"""The time when the message was created."""
163+
157164

158165
StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
159166
"""Type variable for structured content types."""

python/packages/autogen-agentchat/tests/test_assistant_agent.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
SseServerParams,
4242
)
4343
from pydantic import BaseModel, ValidationError
44-
from utils import FileLogHandler
44+
from utils import FileLogHandler, compare_messages, compare_task_results
4545

4646
logger = logging.getLogger(EVENT_LOGGER_NAME)
4747
logger.setLevel(logging.DEBUG)
@@ -180,9 +180,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
180180
index = 0
181181
async for message in agent.run_stream(task="task"):
182182
if isinstance(message, TaskResult):
183-
assert message == result
183+
assert compare_task_results(message, result)
184184
else:
185-
assert message == result.messages[index]
185+
assert compare_messages(message, result.messages[index])
186186
index += 1
187187

188188
# Test state saving and loading.
@@ -273,9 +273,9 @@ async def test_run_with_tools_and_reflection() -> None:
273273
index = 0
274274
async for message in agent.run_stream(task="task"):
275275
if isinstance(message, TaskResult):
276-
assert message == result
276+
assert compare_task_results(message, result)
277277
else:
278-
assert message == result.messages[index]
278+
assert compare_messages(message, result.messages[index])
279279
index += 1
280280

281281
# Test state saving and loading.
@@ -363,9 +363,9 @@ async def test_run_with_parallel_tools() -> None:
363363
index = 0
364364
async for message in agent.run_stream(task="task"):
365365
if isinstance(message, TaskResult):
366-
assert message == result
366+
assert compare_task_results(message, result)
367367
else:
368-
assert message == result.messages[index]
368+
assert compare_messages(message, result.messages[index])
369369
index += 1
370370

371371
# Test state saving and loading.
@@ -446,9 +446,9 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
446446
index = 0
447447
async for message in agent.run_stream(task="task"):
448448
if isinstance(message, TaskResult):
449-
assert message == result
449+
assert compare_task_results(message, result)
450450
else:
451-
assert message == result.messages[index]
451+
assert compare_messages(message, result.messages[index])
452452
index += 1
453453

454454
# Test state saving and loading.
@@ -560,9 +560,9 @@ async def test_run_with_workbench() -> None:
560560
index = 0
561561
async for message in agent.run_stream(task="task"):
562562
if isinstance(message, TaskResult):
563-
assert message == result
563+
assert compare_task_results(message, result)
564564
else:
565-
assert message == result.messages[index]
565+
assert compare_messages(message, result.messages[index])
566566
index += 1
567567

568568
# Test state saving and loading.
@@ -779,9 +779,9 @@ async def test_handoffs() -> None:
779779
index = 0
780780
async for message in tool_use_agent.run_stream(task="task"):
781781
if isinstance(message, TaskResult):
782-
assert message == result
782+
assert compare_task_results(message, result)
783783
else:
784-
assert message == result.messages[index]
784+
assert compare_messages(message, result.messages[index])
785785
index += 1
786786

787787

@@ -852,9 +852,9 @@ async def test_handoff_with_tool_call_context() -> None:
852852
index = 0
853853
async for message in tool_use_agent.run_stream(task="task"):
854854
if isinstance(message, TaskResult):
855-
assert message == result
855+
assert compare_task_results(message, result)
856856
else:
857-
assert message == result.messages[index]
857+
assert compare_messages(message, result.messages[index])
858858
index += 1
859859

860860

@@ -927,9 +927,9 @@ def _next_action(action: str) -> str:
927927
index = 0
928928
async for message in tool_use_agent.run_stream(task="task"):
929929
if isinstance(message, TaskResult):
930-
assert message == result
930+
assert compare_task_results(message, result)
931931
else:
932-
assert message == result.messages[index]
932+
assert compare_messages(message, result.messages[index])
933933
index += 1
934934

935935

@@ -1004,9 +1004,9 @@ def _next_action(action: str) -> Dict[str, str]:
10041004
index = 0
10051005
async for message in tool_use_agent.run_stream(task="task"):
10061006
if isinstance(message, TaskResult):
1007-
assert message == result
1007+
assert compare_task_results(message, result)
10081008
else:
1009-
assert message == result.messages[index]
1009+
assert compare_messages(message, result.messages[index])
10101010
index += 1
10111011

10121012

@@ -1161,9 +1161,9 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
11611161
index = 0
11621162
async for message in agent.run_stream(task=messages):
11631163
if isinstance(message, TaskResult):
1164-
assert message == result
1164+
assert compare_task_results(message, result)
11651165
else:
1166-
assert message == result.messages[index]
1166+
assert compare_messages(message, result.messages[index])
11671167
index += 1
11681168

11691169

0 commit comments

Comments
 (0)