Skip to content

Serialization for the functions, fixed issues in agentchat selector_func and candidate_func serialization #6389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import asyncio
import logging
import re
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core import AgentRuntime, Component, ComponentModel, IndividualFunction
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@@ -30,6 +29,7 @@
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
from autogen_core.code_executor._func_with_reqs import ImportFromModule

trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

@@ -61,9 +61,9 @@ def __init__(
model_client: ChatCompletionClient,
selector_prompt: str,
allow_repeated_speaker: bool,
selector_func: Optional[SelectorFuncType],
selector_func: Optional[IndividualFunction],
max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType],
candidate_func: Optional[IndividualFunction],
emit_team_events: bool,
model_client_streaming: bool = False,
) -> None:
@@ -84,11 +84,9 @@ def __init__(
self._selector_prompt = selector_prompt
self._previous_speaker: str | None = None
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
self._is_selector_func_async = iscoroutinefunction(self._selector_func)
self._selector_func = selector_func if selector_func else None
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
self._candidate_func = candidate_func if candidate_func else None
self._model_client_streaming = model_client_streaming

async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
@@ -124,36 +122,55 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -

# Use the selector function if provided.
if self._selector_func is not None:
if self._is_selector_func_async:
async_selector_func = cast(AsyncSelectorFunc, self._selector_func)
speaker = await async_selector_func(thread)
else:
sync_selector_func = cast(SyncSelectorFunc, self._selector_func)
speaker = sync_selector_func(thread)
if speaker is not None:
if speaker not in self._participant_names:
raise ValueError(
f"Selector function returned an invalid speaker name: {speaker}. "
f"Expected one of: {self._participant_names}."
)
# Skip the model based selection.
return speaker
args_type = self._selector_func.args_type()
args_fields = args_type.model_fields
if "messages" not in args_fields:
raise ValueError(
f"Selector function '{self._selector_func.name}' must have a parameter named 'messages'. "
f"Found parameters: {list(args_fields.keys())}"
)

try:
# Use run_json which handles both sync and async functions
speaker = await self._selector_func.run_json({"messages": thread})
if speaker is not None:
if speaker not in self._participant_names:
raise ValueError(
f"Selector function returned an invalid speaker name: {speaker}. "
f"Expected one of: {self._participant_names}."
)
# Skip the model based selection.
return speaker

except Exception as e:
trace_logger.warning(f"Error executing selector function: {e}")

# Use the candidate function to filter participants if provided
if self._candidate_func is not None:
if self._is_candidate_func_async:
async_candidate_func = cast(AsyncCandidateFunc, self._candidate_func)
participants = await async_candidate_func(thread)
else:
sync_candidate_func = cast(SyncCandidateFunc, self._candidate_func)
participants = sync_candidate_func(thread)
if not participants:
raise ValueError("Candidate function must return a non-empty list of participant names.")
if not all(p in self._participant_names for p in participants):
args_type = self._candidate_func.args_type()
args_fields = args_type.model_fields
if "messages" not in args_fields:
raise ValueError(
f"Candidate function returned invalid participant names: {participants}. "
f"Expected one of: {self._participant_names}."
f"Candidate function '{self._candidate_func.name}' must have a parameter named 'messages'. "
f"Found parameters: {list(args_fields.keys())}"
)
try:
# Use run_json which handles both sync and async functions
participants = await self._candidate_func.run_json({"messages": thread})
if not participants:
raise ValueError("Candidate function must return a non-empty list of participant names.")
if not all(p in self._participant_names for p in participants):
raise ValueError(
f"Candidate function returned invalid participant names: {participants}. "
f"Expected one of: {self._participant_names}."
)
except Exception as e:
trace_logger.warning(f"Error executing candidate function: {e}")
# Fallback to default participant selection on error
if self._previous_speaker is not None and not self._allow_repeated_speaker:
participants = [p for p in self._participant_names if p != self._previous_speaker]
else:
participants = list(self._participant_names)
else:
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
if self._previous_speaker is not None and not self._allow_repeated_speaker:
@@ -308,7 +325,8 @@ class SelectorGroupChatConfig(BaseModel):
max_turns: int | None = None
selector_prompt: str
allow_repeated_speaker: bool
# selector_func: ComponentModel | None
selector_func: ComponentModel | None = None
candidate_func: ComponentModel | None = None
max_selector_attempts: int = 3
emit_team_events: bool = False
model_client_streaming: bool = False
@@ -481,8 +499,8 @@ def __init__(
""",
allow_repeated_speaker: bool = False,
max_selector_attempts: int = 3,
selector_func: Optional[SelectorFuncType] = None,
candidate_func: Optional[CandidateFuncType] = None,
selector_func: Optional[SelectorFuncType | IndividualFunction] = None,
candidate_func: Optional[CandidateFuncType | IndividualFunction] = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
model_client_streaming: bool = False,
@@ -503,9 +521,32 @@ def __init__(
self._selector_prompt = selector_prompt
self._model_client = model_client
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func

# Wrap the selector_func in IndividualFunction if it's a function and not already an IndividualFunction
if selector_func is not None and not isinstance(selector_func, IndividualFunction):
if not callable(selector_func):
raise ValueError("Selector function should be callable")
self._selector_func = IndividualFunction(
selector_func,
description="Function to select the next speaker in a group chat",
name="selector_func",
)
else:
self._selector_func = selector_func

# Wrap the candidate_func in IndividualFunction if it's a function and not already an IndividualFunction
if candidate_func is not None and not isinstance(candidate_func, IndividualFunction):
if not callable(candidate_func):
raise ValueError("Candidate function should be callable")
self._candidate_func = IndividualFunction(
candidate_func,
description="Function to filter candidate speakers in a group chat",
name="candidate_func",
)
else:
self._candidate_func = candidate_func

self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._model_client_streaming = model_client_streaming

def _create_group_chat_manager_factory(
@@ -551,7 +592,8 @@ def _to_config(self) -> SelectorGroupChatConfig:
selector_prompt=self._selector_prompt,
allow_repeated_speaker=self._allow_repeated_speaker,
max_selector_attempts=self._max_selector_attempts,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
selector_func=self._selector_func.dump_component() if self._selector_func else None,
candidate_func=self._candidate_func.dump_component() if self._candidate_func else None,
emit_team_events=self._emit_team_events,
model_client_streaming=self._model_client_streaming,
)
@@ -568,9 +610,12 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
selector_prompt=config.selector_prompt,
allow_repeated_speaker=config.allow_repeated_speaker,
max_selector_attempts=config.max_selector_attempts,
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
# if config.selector_func
# else None,
selector_func=IndividualFunction.load_component(config.selector_func)
if config.selector_func is not None
else None,
candidate_func=IndividualFunction.load_component(config.candidate_func)
if config.candidate_func
else None,
emit_team_events=config.emit_team_events,
model_client_streaming=config.model_client_streaming,
)
4 changes: 4 additions & 0 deletions python/packages/autogen-core/src/autogen_core/__init__.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@
from ._cache_store import CacheStore, InMemoryStore
from ._cancellation_token import CancellationToken
from ._closure_agent import ClosureAgent, ClosureContext
from ._functions import IndividualFunction
from ._function_utils import get_imports_from_func
from ._component_config import (
Component,
ComponentBase,
@@ -132,4 +134,6 @@
"DropMessage",
"InterventionHandler",
"DefaultInterventionHandler",
"IndividualFunction",
"get_imports_from_func",
]
130 changes: 130 additions & 0 deletions python/packages/autogen-core/src/autogen_core/_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
# Credit to original authors

import ast
import inspect
import textwrap
import typing
from functools import partial
from logging import getLogger
@@ -12,6 +14,7 @@
Dict,
List,
Optional,
Protocol,
Set,
Tuple,
Type,
@@ -25,6 +28,7 @@
from pydantic import BaseModel, Field, TypeAdapter, create_model # type: ignore
from pydantic_core import PydanticUndefined
from typing_extensions import Literal
from .code_executor import Import, Alias, ImportFromModule

logger = getLogger(__name__)

@@ -322,3 +326,129 @@ def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[Ba
fields[param_name] = (type, Field(default=default_value, description=description))

return cast(BaseModel, create_model(name, **fields)) # type: ignore


class AnyCallable(Protocol):
    """Structural type matching any object that can be called.

    The protocol places no constraint on the positional arguments, the
    keyword arguments, or the return value of the call.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the object with arbitrary arguments."""
        ...

def get_imports_from_func(func: AnyCallable) -> list[Import]:
    """Extract the imports of the defining module that a function references.

    The source of the module that defines ``func`` is parsed to collect every
    ``import`` / ``from ... import`` statement (including ones nested inside
    other blocks), and the source of ``func`` itself is parsed to collect every
    name it reads. Only imports whose bound name is read somewhere in the
    function definition (decorators, argument annotations and defaults, return
    annotation, or body) are returned.

    Args:
        func: The function to analyze. Both ``def`` and ``async def``
            functions are supported.

    Returns:
        The used imports, plain ``import`` statements first and
        ``from ... import`` statements after them. Plain imports are returned
        as module-name strings, aliased imports as ``Alias`` objects, and
        from-imports grouped per module into ``ImportFromModule`` objects.
        The order is deterministic (sorted by the name used in the function).
        An empty list is returned when the defining module cannot be
        determined, when ``func`` has no ``__name__``, or when no matching
        function definition is found in its source.

    Raises:
        OSError: Propagated from ``inspect.getsource`` when the source code
            cannot be retrieved.
        TypeError: Propagated from ``inspect.getsource`` for built-in objects.
        SyntaxError: Propagated from ``ast.parse`` when the retrieved source
            does not parse on its own.
    """
    defining_module = inspect.getmodule(func)
    if defining_module is None:
        return []

    func_name = getattr(func, "__name__", None)
    if func_name is None:
        return []

    module_ast = ast.parse(inspect.getsource(defining_module))
    # Methods and nested functions come back indented; dedent so the snippet parses standalone.
    func_ast = ast.parse(textwrap.dedent(inspect.getsource(func)))

    # Locate the definition node; async functions use a distinct node type.
    function_def = next(
        (
            node
            for node in ast.walk(func_ast)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name
        ),
        None,
    )
    if function_def is None:
        return []

    # Every name that is read anywhere in the definition. Walking the whole node
    # covers decorators, all kinds of argument annotations, defaults, the return
    # annotation and the body. The base of an attribute access (``mod.attr``) is
    # itself a Name node in Load context, so it is picked up here as well.
    used_names: set[str] = {
        node.id for node in ast.walk(function_def) if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
    }

    # Process import statements in source order so that a later rebinding of
    # the same alias overrides an earlier one.
    import_nodes = sorted(
        (node for node in ast.walk(module_ast) if isinstance(node, (ast.Import, ast.ImportFrom))),
        key=lambda node: (node.lineno, node.col_offset),
    )

    plain_imports: dict[str, list[Any]] = {}  # bound name -> import entries (str or Alias)
    from_imports: dict[str, tuple[str, Any]] = {}  # bound name -> (module, str or Alias)

    for node in import_nodes:
        if isinstance(node, ast.Import):
            for imported in node.names:
                if imported.asname and imported.asname != imported.name:
                    # ``import X as Y`` binds exactly ``Y``.
                    plain_imports[imported.asname] = [Alias(name=imported.name, alias=imported.asname)]
                else:
                    # ``import a.b`` binds the top-level package name ``a``,
                    # but the full dotted path must be re-imported.
                    bound_name = imported.name.partition(".")[0]
                    entries = plain_imports.setdefault(bound_name, [])
                    if imported.name not in entries:
                        entries.append(imported.name)
        else:
            # Relative imports carry their dots in ``level`` and may have no module name.
            module_name = "." * node.level + (node.module or "")
            for imported in node.names:
                if imported.asname and imported.asname != imported.name:
                    from_imports[imported.asname] = (
                        module_name,
                        Alias(name=imported.name, alias=imported.asname),
                    )
                else:
                    from_imports[imported.name] = (module_name, imported.name)

    result: list[Import] = []
    from_modules: dict[str, list[Any]] = {}

    # Sort the used names: set iteration order for strings varies between
    # processes, which would make the output unstable.
    for used_name in sorted(used_names):
        result.extend(plain_imports.get(used_name, []))
        if used_name in from_imports:
            module_name, entry = from_imports[used_name]
            from_modules.setdefault(module_name, []).append(entry)

    # One ImportFromModule per source module, after the plain imports.
    for module_name, entries in from_modules.items():
        result.append(ImportFromModule(module=module_name, imports=entries))

    return result
Loading
Oops, something went wrong.