In [1]:
from abc import ABC
from inspect import signature
from typing import Annotated, Any, Callable, Optional, Union

from fast_depends import Depends as FastDepends
from fast_depends import inject
from pydantic import BaseModel

from autogen.tools import Tool

In [2]:
class BaseContext(ABC):
    """Marker base class for injectable context objects.

    Parameters annotated with a subclass of ``BaseContext`` are removed from the
    public signature of tools created by ``Agent``, and ``Depends`` wraps
    ``BaseContext`` instances in a provider that returns that same instance.
    """

    pass


class ChatContext(BaseContext):
    """Context object carrying a list of chat messages.

    Args:
        messages: Initial message list. When omitted, the instance gets its own new
            empty list.
    """

    messages: list[str]

    def __init__(self, messages: Optional[list[str]] = None) -> None:
        # A class-level ``messages = []`` default would be a single list shared by every
        # ChatContext instance, so appending on one context would leak into all others.
        self.messages = [] if messages is None else messages


def Depends(x: Any) -> Any:
    """Build a ``fast_depends`` dependency marker for ``x``.

    If ``x`` is a ``BaseContext`` instance, it is wrapped in a zero-argument provider
    that always returns that same object, so every call of the decorated function
    receives the identical instance. Any other value is handed to
    ``fast_depends.Depends`` unchanged.
    """
    provider = (lambda: x) if isinstance(x, BaseContext) else x
    return FastDepends(provider)


class Agent:
    """Minimal agent stand-in used to prototype context injection for tools.

    ``register_for_llm`` and ``register_for_execution`` are decorator factories that
    turn a plain function into a ``Tool``. The tool's dependencies are resolved by
    ``fast_depends.inject``, and its ``BaseContext`` parameters are hidden from the
    public signature. ``register_for_execution`` also records the tool's callable in
    ``self.tools`` under the tool name.
    """

    def __init__(self):
        # Tool name -> executable callable (dependency-injected, context params hidden).
        self.tools: dict[str, Callable[..., Any]] = {}

    def _remove_injected_params_from_signature(self, func: Callable[..., Any]) -> None:
        """Drop ``BaseContext``-typed parameters from ``func``'s visible signature.

        A parameter is dropped when its annotation is a ``BaseContext`` subclass, or a
        subscripted alias whose first type argument is one (e.g.
        ``Annotated[MyContext, Depends(ctx)]`` or ``Optional[MyContext]``). The new
        signature is assigned to ``func.__signature__``, so ``func`` is modified in
        place and ``inspect.signature(func)`` no longer reports those parameters.
        """
        sig = signature(func)
        visible_params = []
        for param in sig.parameters.values():
            annotation = param.annotation
            # Subscripted aliases (``Annotated[X, ...]``, ``Optional[X]``, ...) expose ``X``
            # as ``__args__[0]``. Guard against an empty ``__args__`` (e.g. ``tuple[()]``),
            # which would otherwise raise IndexError.
            type_args = getattr(annotation, "__args__", None)
            if type_args:
                annotation = type_args[0]
            if isinstance(annotation, type) and issubclass(annotation, BaseContext):
                continue
            visible_params.append(param)

        func.__signature__ = sig.replace(parameters=visible_params)

    # Copied from ConversableAgent; only the callable branch differs: it injects
    # dependencies and hides context parameters from the signature.
    def _create_tool_if_needed(
        self, func_or_tool: Union[Tool, Callable[..., Any]], name: Optional[str], description: Optional[str]
    ) -> Tool:
        """Return a ``Tool`` for ``func_or_tool``, creating one from a callable if needed.

        Args:
            func_or_tool: An existing ``Tool`` or a callable to wrap. An existing
                ``Tool`` is updated in place when ``name``/``description`` are given.
            name: Tool name. Defaults to the existing tool's name, or to the
                function's ``__name__`` for a callable.
            description: Tool description. Defaults to the existing tool's
                description; for a callable it is passed to ``Tool`` as-is.

        Raises:
            ValueError: If ``func_or_tool`` is neither a ``Tool`` nor callable.
        """
        # Check ``Tool`` first: a Tool object may itself be callable.
        if isinstance(func_or_tool, Tool):
            tool: Tool = func_or_tool
            tool._name = name or tool.name
            tool._description = description or tool.description
            return tool

        if callable(func_or_tool):
            # Resolve ``Depends(...)`` parameters at call time, then hide them from the signature.
            func: Callable[..., Any] = inject(func_or_tool)
            self._remove_injected_params_from_signature(func)

            return Tool(name=name or func.__name__, description=description, func=func)

        raise ValueError(
            f"Parameter 'func_or_tool' must be a function or a Tool instance, it is '{type(func_or_tool)}' instead."
        )

    def register_for_llm(
        self, *, name: Optional[str] = None, description: Optional[str] = None
    ) -> Callable[[Union[Callable[..., Any], Tool]], Tool]:
        """Decorator factory that turns a function (or ``Tool``) into a ``Tool`` for LLM use.

        This prototype does not record the tool on the agent. It only returns the tool,
        so the decorator can be stacked with ``register_for_execution``.
        """

        def decorator(func_or_tool: Union[Callable[..., Any], Tool]) -> Tool:
            return self._create_tool_if_needed(func_or_tool, name, description)

        return decorator

    def register_for_execution(
        self, name: Optional[str] = None
    ) -> Callable[[Union[Callable[..., Any], Tool]], Tool]:
        """Decorator factory that registers a function (or ``Tool``) for execution.

        The tool's callable is stored in ``self.tools`` under the tool's name, and the
        ``Tool`` is returned so decorators can be stacked.
        """

        def decorator(func_or_tool: Union[Callable[..., Any], Tool]) -> Tool:
            tool = self._create_tool_if_needed(func_or_tool, name, None)
            self.tools[tool.name] = tool.func
            return tool

        return decorator

## Context injection

#### Title:
As a developer, I want to assign a default context value for function parameters to streamline dependency injection in tool definitions.

#### Description:
This feature allows developers to define functions with parameters that include default values derived from a context object. The context object can be instantiated and injected without manual handling, simplifying API usage.

#### Acceptance Criteria:
- Functions can specify a default context value as part of their signature.
- Context values are automatically injected during function execution.
- The context parameter must be annotated with a `BaseContext` subclass so that it is hidden from the tool's public signature.


In [3]:
class MyContext(BaseContext, BaseModel):
    """Example context used by the tests below: a pydantic model holding one integer."""

    # Value read (and, in some tests, incremented) by the example tools.
    b: int


# Module-level context instance; ``Depends(ctx)`` below captures this exact object.
ctx = MyContext(b=2)

agent = Agent()


# Decorators apply bottom-up: ``register_for_execution`` wraps ``f`` into a Tool and
# stores its callable in ``agent.tools``; ``register_for_llm`` then sets the description.
@agent.register_for_llm(description="Example function")
@agent.register_for_execution()
def f(
    a: int,
    ctx: Annotated[MyContext, Depends(ctx)],
    # We support the following syntaxes as well:
    # ctx: Annotated[MyContext, Depends(MyContext(b=2))],
    # ctx: MyContext = Depends(MyContext(b=2)), # non-annotated version
    # ctx: MyContext = MyContext(b=2), # non-annotated version for subclasses of BaseContext
) -> int:
    # NOTE: the parameter ``ctx`` shadows the global ``ctx``. ``Depends(ctx)`` in the
    # annotation is evaluated when the ``def`` runs, so it refers to the global object.
    return a + ctx.b

In [4]:
# The tool is stored under the function name and is callable. Its public signature
# no longer lists the injected ``ctx`` parameter.
assert "f" in agent.tools
assert isinstance(agent.tools["f"], Callable)
assert str(signature(agent.tools["f"])) == "(a: int) -> int"

In [5]:
# ``f`` is called without ``ctx``; the global ``ctx`` (b=2) is injected.
assert f(1) == 3
# ``Depends(ctx)`` captured the object itself, so mutating it changes later calls.
ctx.b = 4
assert f(1) == 5

## A separate context for each function

#### Title:
As a developer, I want to inject context objects dynamically into functions without explicitly modifying their signatures.

#### Description:
This feature allows context objects to be bound to functions when they are registered. The function's Python definition stays unchanged, while the context parameter is removed from the signature exposed to LLMs or external callers.

#### Acceptance Criteria:
- Context injection occurs transparently during function execution.
- Registered functions do not expose the ctx parameter in their final API interface.
- Functions retain full type safety and compatibility with the BaseContext.
- Unit tests validate dynamic injection scenarios.


In [6]:
agent1 = Agent()
agent2 = Agent()
user_proxy = Agent()


def _f(
    a: int,
    ctx: MyContext,
) -> int:
    """Increment ``ctx.b`` in place and return ``a`` plus the new value."""
    ctx.b = ctx.b + 1
    return a + ctx.b


# Each function below gets its own ``MyContext(b=2)`` instance, created once when its
# ``def`` statement runs. State therefore persists across calls of the same function
# but is not shared between ``f1`` and ``f2``.
@user_proxy.register_for_execution()
@agent1.register_for_llm(description="Example function")
def f1(
    a: int,
    ctx: Annotated[MyContext, Depends(MyContext(b=2))],
) -> int:
    return _f(a, ctx)


# Registered with ``agent2`` so that each LLM-facing agent owns one function.
@user_proxy.register_for_execution()
@agent2.register_for_llm(description="Example function")
def f2(
    a: int,
    ctx: Annotated[MyContext, Depends(MyContext(b=2))],
) -> int:
    return _f(a, ctx)

In [7]:
assert "f1" in user_proxy.tools
assert "f2" in user_proxy.tools
assert isinstance(user_proxy.tools["f1"], Callable)
assert isinstance(user_proxy.tools["f2"], Callable)
# Acceptance criterion: the injected ``ctx`` parameter must not appear in the
# executable tools' public signatures.
assert str(signature(user_proxy.tools["f1"])) == "(a: int) -> int"
assert str(signature(user_proxy.tools["f2"])) == "(a: int) -> int"

In [8]:
# f1's own context starts at b=2 and is incremented on every call.
assert f1(1) == 4
assert f1(1) == 5
# f2 has a separate context, so f1's calls do not affect it.
assert f2(1) == 4

## Shared context across multiple agents

#### Title:
As a developer, I want multiple agents to share a single context object for efficient and synchronized execution.
#### Description:
Enable multiple agents or users to share the same context object. This allows for coordinated interactions and consistency in function behavior across users while minimizing resource overhead.
#### Acceptance Criteria:
- A shared context object can be assigned to multiple users.
- Shared context updates reflect immediately across all agents.
- Unit tests validate behavior with shared context across agents.


In [9]:
agent1, agent2, user_proxy = Agent(), Agent(), Agent()

# A single context object shared by both tools below.
ctx = MyContext(b=2)


def _f(a: int, ctx: MyContext) -> int:
    """Bump ``ctx.b`` by one and return the new value plus ``a``."""
    ctx.b += 1
    return ctx.b + a


@user_proxy.register_for_execution()
@agent1.register_for_llm(description="Example function")
def f1(a: int, ctx: Annotated[MyContext, Depends(ctx)]) -> int:
    return _f(a, ctx)


@user_proxy.register_for_execution()
@agent2.register_for_llm(description="Example function")
def f2(a: int, ctx: Annotated[MyContext, Depends(ctx)]) -> int:
    return _f(a, ctx)

In [10]:
# Both executable tools are registered and hide ``ctx`` from their public signature.
assert "f1" in user_proxy.tools
assert "f2" in user_proxy.tools
assert isinstance(user_proxy.tools["f1"], Callable)
assert isinstance(user_proxy.tools["f2"], Callable)
assert str(signature(user_proxy.tools["f1"])) == "(a: int) -> int"
assert str(signature(user_proxy.tools["f2"])) == "(a: int) -> int"
# f1 and f2 increment the same shared ``ctx`` (b starts at 2), so the value keeps
# rising no matter which tool is called.
assert user_proxy.tools["f1"](1) == 4
assert user_proxy.tools["f1"](1) == 5
assert user_proxy.tools["f2"](1) == 6
assert user_proxy.tools["f2"](1) == 7
assert user_proxy.tools["f1"](1) == 8

### Tests: signature rewriting and the supported injection syntaxes

In [11]:
# Syntax 1: ``Annotated[MyContext, Depends(<instance>)]`` with a module-level instance.
agent = Agent()
user_proxy = Agent()
ctx = MyContext(b=2)


@user_proxy.register_for_execution()
@agent.register_for_llm(description="Example function")
def test_1(
    a: int,
    ctx: Annotated[MyContext, Depends(ctx)],
) -> int:
    return a + ctx.b


assert "test_1" in user_proxy.tools
assert isinstance(user_proxy.tools["test_1"], Callable)
assert str(signature(user_proxy.tools["test_1"])) == "(a: int) -> int"
assert user_proxy.tools["test_1"](1) == 3

In [12]:
# Syntax 2: ``Depends(<instance>)`` used as the default value, without ``Annotated``.
agent = Agent()
user_proxy = Agent()


@user_proxy.register_for_execution()
@agent.register_for_llm(description="Example function")
def test_2(
    a: int,
    ctx: MyContext = Depends(MyContext(b=3)),
) -> int:
    return a + ctx.b


assert "test_2" in user_proxy.tools
assert isinstance(user_proxy.tools["test_2"], Callable)
assert str(signature(user_proxy.tools["test_2"])) == "(a: int) -> int"
assert user_proxy.tools["test_2"](1) == 4

In [13]:
# Syntax 3: a plain ``MyContext`` instance as the default value, with no ``Depends``.
# The parameter is still hidden from the signature because its annotation is a
# ``BaseContext`` subclass.
agent = Agent()
user_proxy = Agent()


@user_proxy.register_for_execution()
@agent.register_for_llm(description="Example function")
def test_3(
    a: int,
    ctx: MyContext = MyContext(b=4),
) -> int:
    return a + ctx.b


assert "test_3" in user_proxy.tools
assert isinstance(user_proxy.tools["test_3"], Callable)
assert str(signature(user_proxy.tools["test_3"])) == "(a: int) -> int"
assert user_proxy.tools["test_3"](1) == 5