Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ async def get_state(self, task_id: str) -> WorkflowState:
workflow_id=task_id,
)

async def send_event(self, agent: Agent, task: Task, event: Event) -> None:
async def send_event(self, agent: Agent, task: Task, event: Event, request: dict | None = None) -> None:
return await self._temporal_client.send_signal(
workflow_id=task.id,
signal=SignalName.RECEIVE_EVENT.value,
payload=SendEventParams(
agent=agent,
task=task,
event=event,
request=request,
).model_dump(),
)

Expand Down
7 changes: 5 additions & 2 deletions src/agentex/lib/sdk/fastacp/base/base_acp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,14 @@ async def _handle_jsonrpc(self, request: Request):
),
)

# Extract application headers, excluding sensitive/transport headers per FASTACP_* rules
# Extract application headers using allowlist approach (only x-* headers)
# Matches gateway's security filtering rules
# Forward filtered headers via params.request.headers to agent handlers
custom_headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in FASTACP_HEADER_SKIP_EXACT
if key.lower().startswith("x-")
and key.lower() not in FASTACP_HEADER_SKIP_EXACT
and not any(key.lower().startswith(p) for p in FASTACP_HEADER_SKIP_PREFIXES)
}

Expand All @@ -168,6 +170,7 @@ async def _handle_jsonrpc(self, request: Request):
params_data = dict(rpc_request.params) if rpc_request.params else {}

# Add custom headers to the request structure if any headers were provided
# Gateway sends filtered headers via HTTP, SDK extracts and populates params.request
if custom_headers:
params_data["request"] = {"headers": custom_headers}
params = params_model.model_validate(params_data)
Expand Down
40 changes: 26 additions & 14 deletions src/agentex/lib/sdk/fastacp/base/constants.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
from __future__ import annotations

# Header filtering rules for FastACP server
# These rules match the gateway's security filtering

# Prefixes to skip (case-insensitive beginswith checks)
FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = (
"content-",
# Hop-by-hop headers that should not be forwarded
HOP_BY_HOP_HEADERS: set[str] = {
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
"content-length",
"content-encoding",
"host",
"user-agent",
"x-forwarded-",
"sec-",
)
}

# Exact header names to skip (case-insensitive matching done by lowercasing keys)
FASTACP_HEADER_SKIP_EXACT: set[str] = {
"x-agent-api-key",
"connection",
"accept-encoding",
# Sensitive headers that should never be forwarded
BLOCKED_HEADERS: set[str] = {
"authorization",
"cookie",
"content-length",
"transfer-encoding",
"x-agent-api-key",
}

# Legacy constants for backward compatibility
FASTACP_HEADER_SKIP_EXACT: set[str] = HOP_BY_HOP_HEADERS | BLOCKED_HEADERS

FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = (
"x-forwarded-", # proxy headers
"sec-", # security headers added by browsers
)


1 change: 1 addition & 0 deletions src/agentex/lib/sdk/fastacp/impl/temporal_acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def handle_event_send(params: SendEventParams) -> None:
agent=params.agent,
task=params.task,
event=params.event,
request=params.request,
)

except Exception as e:
Expand Down
228 changes: 227 additions & 1 deletion tests/test_header_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, override
import sys
import types
from datetime import datetime, timezone
from unittest.mock import AsyncMock, Mock

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -44,8 +46,14 @@ class _StubTracer(_StubAsyncTracer):

from agentex.lib.core.services.adk.acp.acp import ACPService
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
from agentex.lib.types.acp import RPCMethod, SendMessageParams
from agentex.lib.types.acp import RPCMethod, SendMessageParams, SendEventParams
from agentex.types.task_message_content import TextContent
from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP
from agentex.lib.core.temporal.services.temporal_task_service import TemporalTaskService
from agentex.lib.environment_variables import EnvironmentVariables
from agentex.types.agent import Agent
from agentex.types.task import Task
from agentex.types.event import Event


class DummySpan:
Expand Down Expand Up @@ -313,3 +321,221 @@ def test_filter_headers_all_types() -> None:
assert result == expected



# ============================================================================
# Temporal Header Forwarding Tests
# ============================================================================

@pytest.fixture
def mock_temporal_client():
"""Create a mock TemporalClient"""
client = AsyncMock()
client.send_signal = AsyncMock(return_value=None)
return client


@pytest.fixture
def mock_env_vars():
"""Create mock environment variables"""
env_vars = Mock(spec=EnvironmentVariables)
env_vars.WORKFLOW_NAME = "test-workflow"
env_vars.WORKFLOW_TASK_QUEUE = "test-queue"
return env_vars


@pytest.fixture
def temporal_task_service(mock_temporal_client, mock_env_vars):
"""Create TemporalTaskService with mocked client"""
return TemporalTaskService(
temporal_client=mock_temporal_client,
env_vars=mock_env_vars,
)


@pytest.fixture
def sample_agent():
"""Create a sample agent"""
return Agent(
id="agent-123",
name="test-agent",
description="Test agent",
acp_type="agentic",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)


@pytest.fixture
def sample_task():
"""Create a sample task"""
return Task(id="task-456")


@pytest.fixture
def sample_event():
"""Create a sample event"""
return Event(
id="event-789",
agent_id="agent-123",
task_id="task-456",
sequence_id=1,
content=TextContent(author="user", content="Test message")
)


@pytest.mark.asyncio
async def test_temporal_task_service_send_event_with_headers(
temporal_task_service,
mock_temporal_client,
sample_agent,
sample_task,
sample_event
):
"""Test that TemporalTaskService forwards request headers in signal payload"""
# Given
request_headers = {
"x-user-oauth-credentials": "test-oauth-token",
"x-custom-header": "custom-value"
}
request = {"headers": request_headers}

# When
await temporal_task_service.send_event(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=request
)

# Then
mock_temporal_client.send_signal.assert_called_once()
call_args = mock_temporal_client.send_signal.call_args

# Verify the signal was sent to the correct workflow
assert call_args.kwargs["workflow_id"] == sample_task.id
assert call_args.kwargs["signal"] == "receive_event"

# Verify the payload includes the request with headers
payload = call_args.kwargs["payload"]
assert "request" in payload
assert payload["request"] == request
assert payload["request"]["headers"] == request_headers


@pytest.mark.asyncio
async def test_temporal_task_service_send_event_without_headers(
temporal_task_service,
mock_temporal_client,
sample_agent,
sample_task,
sample_event
):
"""Test that TemporalTaskService handles missing request gracefully"""
# When - Send event without request parameter
await temporal_task_service.send_event(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=None
)

# Then
mock_temporal_client.send_signal.assert_called_once()
call_args = mock_temporal_client.send_signal.call_args

# Verify the payload has request as None
payload = call_args.kwargs["payload"]
assert payload["request"] is None


@pytest.mark.asyncio
async def test_temporal_acp_integration_with_request_headers(
mock_temporal_client,
mock_env_vars,
sample_agent,
sample_task,
sample_event
):
"""Test end-to-end integration: TemporalACP -> TemporalTaskService -> TemporalClient signal"""
# Given - Create real TemporalTaskService with mocked client
task_service = TemporalTaskService(
temporal_client=mock_temporal_client,
env_vars=mock_env_vars,
)

# Create TemporalACP with real task service
temporal_acp = TemporalACP(
temporal_address="localhost:7233",
temporal_task_service=task_service,
)
temporal_acp._setup_handlers()

request_headers = {
"x-user-id": "user-123",
"authorization": "Bearer token",
"x-tenant-id": "tenant-456"
}
request = {"headers": request_headers}

# Create SendEventParams as TemporalACP would receive it
params = SendEventParams(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=request
)

# When - Trigger the event handler via the decorated function
# The handler is registered via @temporal_acp.on_task_event_send
# We'll directly call the task service method as the handler does
await task_service.send_event(
agent=params.agent,
task=params.task,
event=params.event,
request=params.request
)

# Then - Verify the temporal client received the signal with request headers
mock_temporal_client.send_signal.assert_called_once()
call_args = mock_temporal_client.send_signal.call_args

# Verify signal payload includes request with headers
payload = call_args.kwargs["payload"]
assert payload["request"] == request
assert payload["request"]["headers"] == request_headers


@pytest.mark.asyncio
async def test_temporal_task_service_preserves_all_header_types(
temporal_task_service,
mock_temporal_client,
sample_agent,
sample_task,
sample_event
):
"""Test that various header types are preserved correctly"""
# Given - Headers with different patterns
request_headers = {
"x-user-oauth-credentials": "oauth-token-12345",
"authorization": "Bearer jwt-token",
"x-tenant-id": "tenant-999",
"x-custom-app-header": "custom-value"
}
request = {"headers": request_headers}

# When
await temporal_task_service.send_event(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=request
)

# Then - Verify all headers are preserved in the signal payload
call_args = mock_temporal_client.send_signal.call_args
payload = call_args.kwargs["payload"]

assert payload["request"]["headers"] == request_headers
# Verify each header individually
for header_name, header_value in request_headers.items():
assert payload["request"]["headers"][header_name] == header_value