Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/agentex/lib/adk/_modules/acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def create_task(
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
request: dict[str, Any] | None = None,
) -> Task:
"""
Create a new task.
Expand All @@ -71,6 +72,7 @@ async def create_task(
start_to_close_timeout: The start to close timeout for the task.
heartbeat_timeout: The heartbeat timeout for the task.
retry_policy: The retry policy for the task.
request: Additional request context including headers to forward to the agent.

Returns:
The task entry.
Expand All @@ -85,6 +87,7 @@ async def create_task(
params=params,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
),
response_type=Task,
start_to_close_timeout=start_to_close_timeout,
Expand All @@ -99,6 +102,7 @@ async def create_task(
params=params,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
)

async def send_event(
Expand All @@ -112,15 +116,22 @@ async def send_event(
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
request: dict[str, Any] | None = None,
) -> Event:
"""
Send an event to a task.

Args:
task_id: The ID of the task to send the event to.
data: The data to send to the event.
content: The content to send to the event.
agent_id: The ID of the agent to send the event to.
agent_name: The name of the agent to send the event to.
trace_id: The trace ID for the event.
parent_span_id: The parent span ID for the event.
start_to_close_timeout: The start to close timeout for the event.
heartbeat_timeout: The heartbeat timeout for the event.
retry_policy: The retry policy for the event.
request: Additional request context including headers to forward to the agent.

Returns:
The event entry.
Expand All @@ -135,6 +146,7 @@ async def send_event(
content=content,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
),
response_type=None,
start_to_close_timeout=start_to_close_timeout,
Expand All @@ -149,6 +161,7 @@ async def send_event(
content=content,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
)

async def send_message(
Expand All @@ -162,15 +175,22 @@ async def send_message(
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
request: dict[str, Any] | None = None,
) -> List[TaskMessage]:
"""
Send a message to a task.

Args:
task_id: The ID of the task to send the message to.
content: The task message content to send to the task.
task_id: The ID of the task to send the message to.
agent_id: The ID of the agent to send the message to.
agent_name: The name of the agent to send the message to.
trace_id: The trace ID for the message.
parent_span_id: The parent span ID for the message.
start_to_close_timeout: The start to close timeout for the message.
heartbeat_timeout: The heartbeat timeout for the message.
retry_policy: The retry policy for the message.
request: Additional request context including headers to forward to the agent.

Returns:
The message entry.
Expand All @@ -185,6 +205,7 @@ async def send_message(
content=content,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
),
response_type=TaskMessage,
start_to_close_timeout=start_to_close_timeout,
Expand All @@ -199,6 +220,7 @@ async def send_message(
content=content,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
)

async def cancel_task(
Expand All @@ -212,6 +234,7 @@ async def cancel_task(
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
request: dict[str, Any] | None = None,
) -> Task:
"""
Cancel a task by sending cancel request to the agent that owns the task.
Expand All @@ -226,6 +249,7 @@ async def cancel_task(
start_to_close_timeout: The start to close timeout for the task.
heartbeat_timeout: The heartbeat timeout for the task.
retry_policy: The retry policy for the task.
request: Additional request context including headers to forward to the agent.

Returns:
The task entry.
Expand All @@ -244,6 +268,7 @@ async def cancel_task(
agent_name=agent_name,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
),
response_type=None,
start_to_close_timeout=start_to_close_timeout,
Expand All @@ -258,4 +283,5 @@ async def cancel_task(
agent_name=agent_name,
trace_id=trace_id,
parent_span_id=parent_span_id,
request=request,
)
74 changes: 62 additions & 12 deletions src/agentex/lib/core/services/adk/acp/acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from agentex.types.task_message import TaskMessage
from agentex.types.task_message_content import TaskMessageContent
from agentex.types.task_message_content_param import TaskMessageContentParam
from agentex.types.agent_rpc_params import (
ParamsCancelTaskRequest as RpcParamsCancelTaskRequest,
ParamsSendEventRequest as RpcParamsSendEventRequest,
)

logger = make_logger(__name__)

Expand All @@ -30,6 +34,7 @@ async def task_create(
params: dict[str, Any] | None = None,
trace_id: str | None = None,
parent_span_id: str | None = None,
request: dict[str, Any] | None = None,
) -> Task:
trace = self._tracer.trace(trace_id=trace_id)
async with trace.span(
Expand All @@ -43,6 +48,10 @@ async def task_create(
},
) as span:
heartbeat_if_in_workflow("task create")

# Extract headers from request; pass-through to agent
extra_headers = request.get("headers") if request else None

if agent_name:
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
agent_name=agent_name,
Expand All @@ -51,6 +60,7 @@ async def task_create(
"name": name,
"params": params,
},
extra_headers=extra_headers,
)
elif agent_id:
json_rpc_response = await self._agentex_client.agents.rpc(
Expand All @@ -60,6 +70,7 @@ async def task_create(
"name": name,
"params": params,
},
extra_headers=extra_headers,
)
else:
raise ValueError("Either agent_name or agent_id must be provided")
Expand All @@ -78,6 +89,7 @@ async def message_send(
task_name: str | None = None,
trace_id: str | None = None,
parent_span_id: str | None = None,
request: dict[str, Any] | None = None,
) -> List[TaskMessage]:
trace = self._tracer.trace(trace_id=trace_id)
async with trace.span(
Expand All @@ -92,6 +104,10 @@ async def message_send(
},
) as span:
heartbeat_if_in_workflow("message send")

# Extract headers from request; pass-through to agent
extra_headers = request.get("headers") if request else None

if agent_name:
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
agent_name=agent_name,
Expand All @@ -101,6 +117,7 @@ async def message_send(
"content": cast(TaskMessageContentParam, content.model_dump()),
"stream": False,
},
extra_headers=extra_headers,
)
elif agent_id:
json_rpc_response = await self._agentex_client.agents.rpc(
Expand All @@ -111,12 +128,13 @@ async def message_send(
"content": cast(TaskMessageContentParam, content.model_dump()),
"stream": False,
},
extra_headers=extra_headers,
)
else:
raise ValueError("Either agent_name or agent_id must be provided")

task_messages: List[TaskMessage] = []
logger.info(f"json_rpc_response: {json_rpc_response}")
logger.info("json_rpc_response: %s", json_rpc_response)
if isinstance(json_rpc_response.result, list):
for message in json_rpc_response.result:
task_message = TaskMessage.model_validate(message)
Expand All @@ -137,6 +155,7 @@ async def event_send(
task_name: str | None = None,
trace_id: str | None = None,
parent_span_id: str | None = None,
request: dict[str, Any] | None = None,
) -> Event:
trace = self._tracer.trace(trace_id=trace_id)
async with trace.span(
Expand All @@ -146,27 +165,33 @@ async def event_send(
"agent_id": agent_id,
"agent_name": agent_name,
"task_id": task_id,
"task_name": task_name,
"content": content,
},
) as span:
heartbeat_if_in_workflow("event send")

# Extract headers from request; pass-through to agent
extra_headers = request.get("headers") if request else None

rpc_event_params: RpcParamsSendEventRequest = {
"task_id": task_id,
"task_name": task_name,
"content": cast(TaskMessageContentParam, content.model_dump()),
}
if agent_name:
json_rpc_response = await self._agentex_client.agents.rpc_by_name(
agent_name=agent_name,
method="event/send",
params={
"task_id": task_id,
"content": cast(TaskMessageContentParam, content.model_dump()),
},
params=rpc_event_params,
extra_headers=extra_headers,
)
elif agent_id:
json_rpc_response = await self._agentex_client.agents.rpc(
agent_id=agent_id,
method="event/send",
params={
"task_id": task_id,
"content": cast(TaskMessageContentParam, content.model_dump()),
},
params=rpc_event_params,
extra_headers=extra_headers,
)
else:
raise ValueError("Either agent_name or agent_id must be provided")
Expand All @@ -184,15 +209,34 @@ async def task_cancel(
agent_name: str | None = None,
trace_id: str | None = None,
parent_span_id: str | None = None,
) -> Task:
request: dict[str, Any] | None = None,
) -> Task:
"""
Cancel a task by sending cancel request to the agent that owns the task.

Args:
task_id: ID of the task to cancel (passed to agent in params)
task_name: Name of the task to cancel (passed to agent in params)
agent_id: ID of the agent that owns the task
agent_name: Name of the agent that owns the task
trace_id: Trace ID for tracing
parent_span_id: Parent span ID for tracing
request: Additional request context including headers to forward to the agent

Returns:
Task entry representing the cancelled task

Raises:
ValueError: If neither agent_name nor agent_id is provided,
or if neither task_name nor task_id is provided
"""
# Require agent identification
if not agent_name and not agent_id:
raise ValueError("Either agent_name or agent_id must be provided to identify the agent that owns the task")

# Require task identification
if not task_name and not task_id:
raise ValueError("Either task_name or task_id must be provided to identify the task to cancel")

trace = self._tracer.trace(trace_id=trace_id)
async with trace.span(
parent_id=parent_span_id,
Expand All @@ -206,8 +250,11 @@ async def task_cancel(
) as span:
heartbeat_if_in_workflow("task cancel")

# Extract headers from request; pass-through to agent
extra_headers = request.get("headers") if request else None

# Build params for the agent (task identification)
params = {}
params: RpcParamsCancelTaskRequest = {}
if task_id:
params["task_id"] = task_id
if task_name:
Expand All @@ -219,12 +266,15 @@ async def task_cancel(
agent_name=agent_name,
method="task/cancel",
params=params,
extra_headers=extra_headers,
)
else: # agent_id is provided (validated above)
assert agent_id is not None
json_rpc_response = await self._agentex_client.agents.rpc(
agent_id=agent_id,
method="task/cancel",
params=params,
extra_headers=extra_headers,
)

task_entry = Task.model_validate(json_rpc_response.result)
Expand Down
Loading
Loading