Skip to content

Commit b31b4e5

Browse files
ekzhuCopilot
andauthored
Add callable condition for GraphFlow edges (#6623)
This PR adds callable as an option to specify conditional edges in GraphFlow. ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.conditions import MaxMessageTermination from autogen_agentchat.teams import DiGraphBuilder, GraphFlow from autogen_ext.models.openai import OpenAIChatCompletionClient async def main(): # Initialize agents with OpenAI model clients. model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") agent_a = AssistantAgent( "A", model_client=model_client, system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.", ) agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.") agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.") # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise). builder = DiGraphBuilder() builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) # Create conditions as callables that check the message content. builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text()) builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text()) graph = builder.build() # Create a GraphFlow team with the directed graph. team = GraphFlow( participants=[agent_a, agent_b, agent_c], graph=graph, termination_condition=MaxMessageTermination(5), ) # Run the team and print the events. async for event in team.run_stream(task="AutoGen is a framework for building AI agents."): print(event) asyncio.run(main()) ``` --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
1 parent 9065c6f commit b31b4e5

File tree

4 files changed

+244
-54
lines changed

4 files changed

+244
-54
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import asyncio
22
from collections import Counter, deque
3-
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set
3+
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
44

55
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
6-
from pydantic import BaseModel
6+
from pydantic import BaseModel, Field, model_validator
77
from typing_extensions import Self
88

99
from autogen_agentchat.agents import BaseChatAgent
@@ -34,16 +34,50 @@ class DiGraphEdge(BaseModel):
3434
3535
This is an experimental feature, and the API will change in the future releases.
3636
37+
.. warning::
38+
39+
If the condition is a callable, it will not be serialized in the model.
40+
3741
"""
3842

3943
target: str # Target node name
40-
condition: str | None = None # Optional execution condition (trigger-based)
44+
condition: Union[str, Callable[[BaseChatMessage], bool], None] = Field(default=None)
4145
"""(Experimental) Condition to execute this edge.
42-
If None, the edge is unconditional. If a string, the edge is conditional on the presence of that string in the last agent chat message.
43-
NOTE: This is an experimental feature WILL change in the future releases to allow for better spcification of branching conditions
44-
similar to the `TerminationCondition` class.
46+
If None, the edge is unconditional.
47+
If a string, the edge is conditional on the presence of that string in the last agent chat message.
48+
If a callable, the edge is conditional on the callable returning True when given the last message.
4549
"""
4650

51+
# Using Field to exclude the condition in serialization if it's a callable
52+
condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True)
53+
54+
@model_validator(mode="after")
55+
def _validate_condition(self) -> "DiGraphEdge":
56+
# Store callable in a separate field and set condition to None for serialization
57+
if callable(self.condition):
58+
self.condition_function = self.condition
59+
# For serialization purposes, we'll set the condition to None
60+
# when storing as a pydantic model/dict
61+
object.__setattr__(self, "condition", None)
62+
return self
63+
64+
def check_condition(self, message: BaseChatMessage) -> bool:
65+
"""Check if the edge condition is satisfied for the given message.
66+
67+
Args:
68+
message: The message to check the condition against.
69+
70+
Returns:
71+
True if condition is satisfied (None condition always returns True),
72+
False otherwise.
73+
"""
74+
if self.condition_function is not None:
75+
return self.condition_function(message)
76+
elif isinstance(self.condition, str):
77+
# If it's a string, check if the string is in the message content
78+
return self.condition in message.to_model_text()
79+
return True # None condition is always satisfied
80+
4781

4882
class DiGraphNode(BaseModel):
4983
"""Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.
@@ -125,7 +159,7 @@ def dfs(node_name: str) -> bool:
125159
cycle_edges: List[DiGraphEdge] = []
126160
for n in cycle_nodes:
127161
cycle_edges.extend(self.nodes[n].edges)
128-
if not any(edge.condition for edge in cycle_edges):
162+
if not any(edge.condition is not None for edge in cycle_edges):
129163
raise ValueError(
130164
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
131165
)
@@ -164,7 +198,7 @@ def graph_validate(self) -> None:
164198
# Outgoing edge condition validation (per node)
165199
for node in self.nodes.values():
166200
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
167-
has_condition = any(edge.condition for edge in node.edges)
201+
has_condition = any(edge.condition is not None for edge in node.edges)
168202
has_unconditioned = any(edge.condition is None for edge in node.edges)
169203
if has_condition and has_unconditioned:
170204
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
@@ -239,11 +273,11 @@ async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseCh
239273
return
240274
assert isinstance(message, BaseChatMessage)
241275
source = message.source
242-
content = message.to_model_text()
243276

244277
# Propagate the update to the children of the node.
245278
for edge in self._edges[source]:
246-
if edge.condition and edge.condition not in content:
279+
# Use the new check_condition method that handles both string and callable conditions
280+
if not edge.check_condition(message):
247281
continue
248282
if self._activation[edge.target] == "all":
249283
self._remaining[edge.target] -= 1
@@ -360,6 +394,11 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
360394
See the :class:`DiGraphBuilder` documentation for more details.
361395
The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder` for creating complex workflows.
362396
397+
.. warning::
398+
399+
When using callable conditions in edges, they will not be serialized
400+
when calling :meth:`dump_component`. This will be addressed in future releases.
401+
363402
364403
Args:
365404
participants (List[ChatAgent]): The participants in the group chat.
@@ -450,7 +489,7 @@ async def main():
450489
451490
asyncio.run(main())
452491
453-
**Conditional Branching: A → B (if 'yes') or C (if 'no')**
492+
**Conditional Branching: A → B (if 'yes') or C (otherwise)**
454493
455494
.. code-block:: python
456495
@@ -473,11 +512,12 @@ async def main():
473512
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
474513
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")
475514
476-
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
515+
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
477516
builder = DiGraphBuilder()
478517
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
479-
builder.add_edge(agent_a, agent_b, condition="yes")
480-
builder.add_edge(agent_a, agent_c, condition="no")
518+
# Create conditions as callables that check the message content.
519+
builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
520+
builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
481521
graph = builder.build()
482522
483523
# Create a GraphFlow team with the directed graph.
@@ -494,7 +534,7 @@ async def main():
494534
495535
asyncio.run(main())
496536
497-
**Loop with exit condition: A → B → C (if 'APPROVE') or A (if 'REJECT')**
537+
**Loop with exit condition: A → B → C (if 'APPROVE') or A (otherwise)**
498538
499539
.. code-block:: python
500540
@@ -518,17 +558,21 @@ async def main():
518558
"B",
519559
model_client=model_client,
520560
system_message="Provide feedback on the input, if your feedback has been addressed, "
521-
"say 'APPROVE', else say 'REJECT' and provide a reason.",
561+
"say 'APPROVE', otherwise provide a reason for rejection.",
522562
)
523563
agent_c = AssistantAgent(
524564
"C", model_client=model_client, system_message="Translate the final product to Korean."
525565
)
526566
527-
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A ("REJECT").
567+
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A (otherwise).
528568
builder = DiGraphBuilder()
529569
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
530570
builder.add_edge(agent_a, agent_b)
531-
builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
571+
572+
# Create conditional edges using strings
573+
builder.add_edge(agent_b, agent_c, condition=lambda msg: "APPROVE" in msg.to_model_text())
574+
builder.add_edge(agent_b, agent_a, condition=lambda msg: "APPROVE" not in msg.to_model_text())
575+
532576
builder.set_entry_point(agent_a)
533577
graph = builder.build()
534578

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_graph_builder.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import Dict, Literal, Optional, Union
1+
import warnings
2+
from typing import Callable, Dict, Literal, Optional, Union
23

34
from autogen_agentchat.base import ChatAgent
5+
from autogen_agentchat.messages import BaseChatMessage
46

57
from ._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode
68

@@ -22,7 +24,7 @@ class DiGraphBuilder:
2224
- Cyclic loops with safe exits
2325
2426
Each node in the graph represents an agent. Edges define execution paths between agents,
25-
and can optionally be conditioned on message content.
27+
and can optionally be conditioned on message content using callable functions.
2628
2729
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
2830
@@ -49,16 +51,29 @@ class DiGraphBuilder:
4951
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
5052
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
5153
52-
Example — Conditional Branching A → B ("yes"), A → C ("no"):
54+
Example — Conditional Branching A → B or A → C:
5355
>>> builder = GraphBuilder()
5456
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
55-
>>> builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})
57+
>>> # Add conditional edges using keyword check
58+
>>> builder.add_edge(agent_a, agent_b, condition="keyword1")
59+
>>> builder.add_edge(agent_a, agent_c, condition="keyword2")
5660
57-
Example — Loop: A → B → A ("loop"), B → C ("exit"):
61+
62+
Example — Using Custom String Conditions:
63+
>>> builder = GraphBuilder()
64+
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
65+
>>> # Add condition strings to check in messages
66+
>>> builder.add_edge(agent_a, agent_b, condition="big")
67+
>>> builder.add_edge(agent_a, agent_c, condition="small")
68+
69+
Example — Loop: A → B → A or B → C:
5870
>>> builder = GraphBuilder()
5971
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
6072
>>> builder.add_edge(agent_a, agent_b)
61-
>>> builder.add_conditional_edges(agent_b, {"loop": agent_a, "exit": agent_c})
73+
>> # Add a loop back to agent A
74+
>>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text())
75+
>>> # Add exit condition to break the loop
76+
>>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text())
6277
"""
6378

6479
def __init__(self) -> None:
@@ -78,9 +93,26 @@ def add_node(self, agent: ChatAgent, activation: Literal["all", "any"] = "all")
7893
return self
7994

8095
def add_edge(
81-
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[str] = None
96+
self,
97+
source: Union[str, ChatAgent],
98+
target: Union[str, ChatAgent],
99+
condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None,
82100
) -> "DiGraphBuilder":
83-
"""Add a directed edge from source to target, optionally with a condition."""
101+
"""Add a directed edge from source to target, optionally with a condition.
102+
103+
Args:
104+
source: Source node (agent name or agent object)
105+
target: Target node (agent name or agent object)
106+
condition: Optional condition for edge activation.
107+
If string, activates when substring is found in message.
108+
If callable, activates when function returns True for the message.
109+
110+
Returns:
111+
Self for method chaining
112+
113+
Raises:
114+
ValueError: If source or target node doesn't exist in the builder
115+
"""
84116
source_name = self._get_name(source)
85117
target_name = self._get_name(target)
86118

@@ -95,9 +127,35 @@ def add_edge(
95127
def add_conditional_edges(
96128
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
97129
) -> "DiGraphBuilder":
98-
"""Add multiple conditional edges from a source node based on condition strings."""
99-
for condition, target in condition_to_target.items():
100-
self.add_edge(source, target, condition)
130+
"""Add multiple conditional edges from a source node based on keyword checks.
131+
132+
.. warning::
133+
134+
This method interface will be changed in the future to support callable conditions.
135+
Please use `add_edge` if you need to specify custom conditions.
136+
137+
Args:
138+
source: Source node (agent name or agent object)
139+
condition_to_target: Mapping from condition strings to target nodes
140+
Each key is a keyword that will be checked in the message content
141+
Each value is the target node to activate when condition is met
142+
143+
For each key (keyword), a lambda will be created that checks
144+
if the keyword is in the message text.
145+
146+
Returns:
147+
Self for method chaining
148+
"""
149+
150+
warnings.warn(
151+
"add_conditional_edges will be changed in the future to support callable conditions. "
152+
"For now, please use add_edge if you need to specify custom conditions.",
153+
DeprecationWarning,
154+
stacklevel=2,
155+
)
156+
157+
for condition_keyword, target in condition_to_target.items():
158+
self.add_edge(source, target, condition=condition_keyword)
101159
return self
102160

103161
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":

0 commit comments

Comments
 (0)