Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/django_ai_core/contrib/agents/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import inspect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated, get_args, get_origin
from typing import Annotated, Callable, TypeVar, get_args, get_origin

from django.core.exceptions import ValidationError
from django.core.validators import validate_slug

from .permissions import BasePermission
from .views import AgentExecutionView

AgentT = TypeVar("AgentT", bound="Agent")


@dataclass
class AgentParameter:
Expand Down Expand Up @@ -79,15 +81,20 @@ class AgentRegistry:
def __init__(self):
self._agents: dict[str, type[Agent]] = {}

def register(self):
"""Decorator to register an agent."""

def decorator(cls: type[Agent]) -> type[Agent]:
agent_slug = cls.slug
self._agents[agent_slug] = cls
return cls

return decorator
def register(
self, cls: type[AgentT] | None = None
) -> type[AgentT] | Callable[[type[AgentT]], type[AgentT]]:
def decorator(agent_cls: type[AgentT]) -> type[AgentT]:
agent_slug = agent_cls.slug
self._agents[agent_slug] = agent_cls
return agent_cls

if cls is None:
# Called with parentheses: @registry.register()
return decorator
else:
# Called without parentheses: @registry.register
return decorator(cls)

def get(self, slug: str) -> type[Agent]:
if slug not in self._agents:
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/contrib/agents/test_agent_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def test_agent_registry_register_uses_agent_slug():
assert registry._agents["test-one"] is TestAgentOne


def test_agent_registry_register_without_parentheses():
"""Test that register can be called without parentheses."""
registry = AgentRegistry()

decorated = registry.register(TestAgentOne)
assert decorated is TestAgentOne
assert "test-one" in registry._agents
assert registry._agents["test-one"] is TestAgentOne


def test_agent_registry_register_multiple_agents():
"""Test registering multiple agents with the registry."""
registry = AgentRegistry()
Expand Down
Loading