Skip to content

Commit c99aa74

Browse files
authored
Fix graph validation logic and add tests (#6630)
Follow up to #6629
1 parent 1b32eb6 commit c99aa74

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,10 @@ def graph_validate(self) -> None:
198198
# Outgoing edge condition validation (per node)
199199
for node in self.nodes.values():
200200
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
201-
has_condition = any(edge.condition is not None for edge in node.edges)
202-
has_unconditioned = any(edge.condition is None for edge in node.edges)
201+
has_condition = any(
202+
edge.condition is not None or edge.condition_function is not None for edge in node.edges
203+
)
204+
has_unconditioned = any(edge.condition is None and edge.condition_function is None for edge in node.edges)
203205
if has_condition and has_unconditioned:
204206
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
205207

python/packages/autogen-agentchat/tests/test_group_chat_graph.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,19 @@ def test_validate_graph_success() -> None:
259259
graph.graph_validate()
260260
assert not graph.get_has_cycles()
261261

262+
# Use a lambda condition
263+
graph_with_lambda = DiGraph(
264+
nodes={
265+
"A": DiGraphNode(
266+
name="A", edges=[DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text())]
267+
),
268+
"B": DiGraphNode(name="B", edges=[]),
269+
}
270+
)
271+
# No error should be raised
272+
graph_with_lambda.graph_validate()
273+
assert not graph_with_lambda.get_has_cycles()
274+
262275

263276
def test_validate_graph_missing_start_node() -> None:
264277
"""Test validation failure when no start node exists."""
@@ -298,6 +311,23 @@ def test_validate_graph_mixed_conditions() -> None:
298311
with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
299312
graph.graph_validate()
300313

314+
# Use lambda for condition
315+
graph_with_lambda = DiGraph(
316+
nodes={
317+
"A": DiGraphNode(
318+
name="A",
319+
edges=[
320+
DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text()),
321+
DiGraphEdge(target="C"),
322+
],
323+
),
324+
"B": DiGraphNode(name="B", edges=[]),
325+
"C": DiGraphNode(name="C", edges=[]),
326+
}
327+
)
328+
with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
329+
graph_with_lambda.graph_validate()
330+
301331

302332
@pytest.mark.asyncio
303333
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
@@ -603,6 +633,29 @@ async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | Non
603633
result = await team.run(task="Trigger yes")
604634
assert result.messages[2].source == "B"
605635

636+
# Use lambda conditions
637+
graph_with_lambda = DiGraph(
638+
nodes={
639+
"A": DiGraphNode(
640+
name="A",
641+
edges=[
642+
DiGraphEdge(target="B", condition=lambda msg: "yes" in msg.to_model_text()),
643+
DiGraphEdge(target="C", condition=lambda msg: "no" in msg.to_model_text()),
644+
],
645+
),
646+
"B": DiGraphNode(name="B", edges=[], activation="any"),
647+
"C": DiGraphNode(name="C", edges=[], activation="any"),
648+
}
649+
)
650+
team_with_lambda = GraphFlow(
651+
participants=[agent_a, agent_b, agent_c],
652+
graph=graph_with_lambda,
653+
runtime=runtime,
654+
termination_condition=MaxMessageTermination(5),
655+
)
656+
result_with_lambda = await team_with_lambda.run(task="Trigger no")
657+
assert result_with_lambda.messages[2].source == "C"
658+
606659

607660
@pytest.mark.asyncio
608661
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
@@ -785,6 +838,31 @@ async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | N
785838
result = await team.run(task="banana")
786839
assert result.messages[2].source == "C"
787840

841+
# Use lambda conditions
842+
graph_with_lambda = DiGraph(
843+
nodes={
844+
"A": DiGraphNode(
845+
name="A",
846+
edges=[
847+
DiGraphEdge(target="B", condition=lambda msg: "apple" in msg.to_model_text()),
848+
DiGraphEdge(target="C", condition=lambda msg: "banana" in msg.to_model_text()),
849+
DiGraphEdge(target="D", condition=lambda msg: "cherry" in msg.to_model_text()),
850+
],
851+
),
852+
"B": DiGraphNode(name="B", edges=[]),
853+
"C": DiGraphNode(name="C", edges=[]),
854+
"D": DiGraphNode(name="D", edges=[]),
855+
}
856+
)
857+
team_with_lambda = GraphFlow(
858+
participants=[agent_a, agent_b, agent_c, agent_d],
859+
graph=graph_with_lambda,
860+
runtime=runtime,
861+
termination_condition=MaxMessageTermination(5),
862+
)
863+
result_with_lambda = await team_with_lambda.run(task="cherry")
864+
assert result_with_lambda.messages[2].source == "D"
865+
788866

789867
class _TestMessageFilterAgentConfig(BaseModel):
790868
name: str

0 commit comments

Comments
 (0)