diff --git a/src/agentex/lib/adk/_modules/acp.py b/src/agentex/lib/adk/_modules/acp.py index 390fea43..a7e390bc 100644 --- a/src/agentex/lib/adk/_modules/acp.py +++ b/src/agentex/lib/adk/_modules/acp.py @@ -59,6 +59,7 @@ async def create_task( start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=5), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + request: dict[str, Any] | None = None, ) -> Task: """ Create a new task. @@ -71,6 +72,7 @@ async def create_task( start_to_close_timeout: The start to close timeout for the task. heartbeat_timeout: The heartbeat timeout for the task. retry_policy: The retry policy for the task. + request: Additional request context including headers to forward to the agent. Returns: The task entry. @@ -85,6 +87,7 @@ async def create_task( params=params, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ), response_type=Task, start_to_close_timeout=start_to_close_timeout, @@ -99,6 +102,7 @@ async def create_task( params=params, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ) async def send_event( @@ -112,15 +116,22 @@ async def send_event( start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=5), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + request: dict[str, Any] | None = None, ) -> Event: """ Send an event to a task. Args: task_id: The ID of the task to send the event to. - data: The data to send to the event. + content: The content to send to the event. agent_id: The ID of the agent to send the event to. agent_name: The name of the agent to send the event to. + trace_id: The trace ID for the event. + parent_span_id: The parent span ID for the event. + start_to_close_timeout: The start to close timeout for the event. + heartbeat_timeout: The heartbeat timeout for the event. + retry_policy: The retry policy for the event. + request: Additional request context including headers to forward to the agent. Returns: The event entry. @@ -135,6 +146,7 @@ async def send_event( content=content, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ), response_type=None, start_to_close_timeout=start_to_close_timeout, @@ -149,6 +161,7 @@ async def send_event( content=content, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ) async def send_message( @@ -162,15 +175,22 @@ async def send_message( start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=5), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + request: dict[str, Any] | None = None, ) -> List[TaskMessage]: """ Send a message to a task. Args: - task_id: The ID of the task to send the message to. content: The task message content to send to the task. + task_id: The ID of the task to send the message to. agent_id: The ID of the agent to send the message to. agent_name: The name of the agent to send the message to. + trace_id: The trace ID for the message. + parent_span_id: The parent span ID for the message. + start_to_close_timeout: The start to close timeout for the message. + heartbeat_timeout: The heartbeat timeout for the message. + retry_policy: The retry policy for the message. + request: Additional request context including headers to forward to the agent. Returns: The message entry. @@ -185,6 +205,7 @@ async def send_message( content=content, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ), response_type=TaskMessage, start_to_close_timeout=start_to_close_timeout, @@ -199,6 +220,7 @@ async def send_message( content=content, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ) async def cancel_task( @@ -212,6 +234,7 @@ async def cancel_task( start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=5), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + request: dict[str, Any] | None = None, ) -> Task: """ Cancel a task by sending cancel request to the agent that owns the task. @@ -226,6 +249,7 @@ async def cancel_task( start_to_close_timeout: The start to close timeout for the task. heartbeat_timeout: The heartbeat timeout for the task. retry_policy: The retry policy for the task. + request: Additional request context including headers to forward to the agent. Returns: The task entry. @@ -244,6 +268,7 @@ async def cancel_task( agent_name=agent_name, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ), response_type=None, start_to_close_timeout=start_to_close_timeout, @@ -258,4 +283,5 @@ async def cancel_task( agent_name=agent_name, trace_id=trace_id, parent_span_id=parent_span_id, + request=request, ) diff --git a/src/agentex/lib/core/services/adk/acp/acp.py b/src/agentex/lib/core/services/adk/acp/acp.py index ca56cec8..e07a6548 100644 --- a/src/agentex/lib/core/services/adk/acp/acp.py +++ b/src/agentex/lib/core/services/adk/acp/acp.py @@ -9,6 +9,10 @@ from agentex.types.task_message import TaskMessage from agentex.types.task_message_content import TaskMessageContent from agentex.types.task_message_content_param import TaskMessageContentParam +from agentex.types.agent_rpc_params import ( + ParamsCancelTaskRequest as RpcParamsCancelTaskRequest, + ParamsSendEventRequest as RpcParamsSendEventRequest, +) logger = make_logger(__name__) @@ -30,6 +34,7 @@ async def task_create( params: dict[str, Any] | None = None, trace_id: str | None = None, parent_span_id: str | None = None, + request: dict[str, Any] | None = None, ) -> Task: trace = self._tracer.trace(trace_id=trace_id) async with trace.span( @@ -43,6 +48,10 @@ async def task_create( }, ) as span: heartbeat_if_in_workflow("task create") + + # Extract headers from request; pass-through to agent + extra_headers = request.get("headers") if request else None + if agent_name: json_rpc_response = await self._agentex_client.agents.rpc_by_name( agent_name=agent_name, @@ -51,6 +60,7 @@ async def task_create( "name": name, "params": params, }, + extra_headers=extra_headers, ) elif agent_id: json_rpc_response = await self._agentex_client.agents.rpc( @@ -60,6 +70,7 @@ async def task_create( "name": name, "params": params, }, + extra_headers=extra_headers, ) else: raise ValueError("Either agent_name or agent_id must be provided") @@ -78,6 +89,7 @@ async def message_send( task_name: str | None = None, trace_id: str | None = None, parent_span_id: str | None = None, + request: dict[str, Any] | None = None, ) -> List[TaskMessage]: trace = self._tracer.trace(trace_id=trace_id) async with trace.span( @@ -92,6 +104,10 @@ async def message_send( }, ) as span: heartbeat_if_in_workflow("message send") + + # Extract headers from request; pass-through to agent + extra_headers = request.get("headers") if request else None + if agent_name: json_rpc_response = await self._agentex_client.agents.rpc_by_name( agent_name=agent_name, @@ -101,6 +117,7 @@ async def message_send( "content": cast(TaskMessageContentParam, content.model_dump()), "stream": False, }, + extra_headers=extra_headers, ) elif agent_id: json_rpc_response = await self._agentex_client.agents.rpc( @@ -111,12 +128,13 @@ async def message_send( "content": cast(TaskMessageContentParam, content.model_dump()), "stream": False, }, + extra_headers=extra_headers, ) else: raise ValueError("Either agent_name or agent_id must be provided") task_messages: List[TaskMessage] = [] - logger.info(f"json_rpc_response: {json_rpc_response}") + logger.info("json_rpc_response: %s", json_rpc_response) if isinstance(json_rpc_response.result, list): for message in json_rpc_response.result: task_message = TaskMessage.model_validate(message) @@ -137,6 +155,7 @@ async def event_send( task_name: str | None = None, trace_id: str | None = None, parent_span_id: str | None = None, + request: dict[str, Any] | None = None, ) -> Event: trace = self._tracer.trace(trace_id=trace_id) async with trace.span( @@ -146,27 +165,33 @@ async def event_send( "agent_id": agent_id, "agent_name": agent_name, "task_id": task_id, + "task_name": task_name, "content": content, }, ) as span: heartbeat_if_in_workflow("event send") + + # Extract headers from request; pass-through to agent + extra_headers = request.get("headers") if request else None + + rpc_event_params: RpcParamsSendEventRequest = { + "task_id": task_id, + "task_name": task_name, + "content": cast(TaskMessageContentParam, content.model_dump()), + } if agent_name: json_rpc_response = await self._agentex_client.agents.rpc_by_name( agent_name=agent_name, method="event/send", - params={ - "task_id": task_id, - "content": cast(TaskMessageContentParam, content.model_dump()), - }, + params=rpc_event_params, + extra_headers=extra_headers, ) elif agent_id: json_rpc_response = await self._agentex_client.agents.rpc( agent_id=agent_id, method="event/send", - params={ - "task_id": task_id, - "content": cast(TaskMessageContentParam, content.model_dump()), - }, + params=rpc_event_params, + extra_headers=extra_headers, ) else: raise ValueError("Either agent_name or agent_id must be provided") @@ -184,7 +209,27 @@ async def task_cancel( agent_name: str | None = None, trace_id: str | None = None, parent_span_id: str | None = None, - ) -> Task: + request: dict[str, Any] | None = None, + ) -> Task: + """ + Cancel a task by sending cancel request to the agent that owns the task. + + Args: + task_id: ID of the task to cancel (passed to agent in params) + task_name: Name of the task to cancel (passed to agent in params) + agent_id: ID of the agent that owns the task + agent_name: Name of the agent that owns the task + trace_id: Trace ID for tracing + parent_span_id: Parent span ID for tracing + request: Additional request context including headers to forward to the agent + + Returns: + Task entry representing the cancelled task + + Raises: + ValueError: If neither agent_name nor agent_id is provided, + or if neither task_name nor task_id is provided + """ # Require agent identification if not agent_name and not agent_id: raise ValueError("Either agent_name or agent_id must be provided to identify the agent that owns the task") @@ -192,7 +237,6 @@ async def task_cancel( # Require task identification if not task_name and not task_id: raise ValueError("Either task_name or task_id must be provided to identify the task to cancel") - trace = self._tracer.trace(trace_id=trace_id) async with trace.span( parent_id=parent_span_id, @@ -206,8 +250,11 @@ async def task_cancel( ) as span: heartbeat_if_in_workflow("task cancel") + # Extract headers from request; pass-through to agent + extra_headers = request.get("headers") if request else None + # Build params for the agent (task identification) - params = {} + params: RpcParamsCancelTaskRequest = {} if task_id: params["task_id"] = task_id if task_name: @@ -219,12 +266,15 @@ async def task_cancel( agent_name=agent_name, method="task/cancel", params=params, + extra_headers=extra_headers, ) else: # agent_id is provided (validated above) + assert agent_id is not None json_rpc_response = await self._agentex_client.agents.rpc( agent_id=agent_id, method="task/cancel", params=params, + extra_headers=extra_headers, ) task_entry = Task.model_validate(json_rpc_response.result) diff --git a/src/agentex/lib/core/temporal/activities/adk/acp/acp_activities.py b/src/agentex/lib/core/temporal/activities/adk/acp/acp_activities.py index ecdbd5cf..be81e7ab 100644 --- a/src/agentex/lib/core/temporal/activities/adk/acp/acp_activities.py +++ b/src/agentex/lib/core/temporal/activities/adk/acp/acp_activities.py @@ -26,6 +26,7 @@ class TaskCreateParams(BaseModelWithTraceParams): agent_id: str | None = None agent_name: str | None = None params: dict[str, Any] | None = None + request: dict[str, Any] | None = None class MessageSendParams(BaseModelWithTraceParams): @@ -33,6 +34,7 @@ class MessageSendParams(BaseModelWithTraceParams): agent_name: str | None = None task_id: str | None = None content: TaskMessageContent + request: dict[str, Any] | None = None class EventSendParams(BaseModelWithTraceParams): @@ -40,6 +42,7 @@ class EventSendParams(BaseModelWithTraceParams): agent_name: str | None = None task_id: str | None = None content: TaskMessageContent + request: dict[str, Any] | None = None class TaskCancelParams(BaseModelWithTraceParams): @@ -47,6 +50,7 @@ class TaskCancelParams(BaseModelWithTraceParams): task_name: str | None = None agent_id: str | None = None agent_name: str | None = None + request: dict[str, Any] | None = None class ACPActivities: @@ -60,6 +64,9 @@ async def task_create(self, params: TaskCreateParams) -> Task: agent_id=params.agent_id, agent_name=params.agent_name, params=params.params, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + request=params.request, ) @activity.defn(name=ACPActivityName.MESSAGE_SEND) @@ -69,6 +76,9 @@ async def message_send(self, params: MessageSendParams) -> List[TaskMessage]: agent_name=params.agent_name, task_id=params.task_id, content=params.content, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + request=params.request, ) @activity.defn(name=ACPActivityName.EVENT_SEND) @@ -78,6 +88,9 @@ async def event_send(self, params: EventSendParams) -> Event: agent_name=params.agent_name, task_id=params.task_id, content=params.content, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + request=params.request, ) @activity.defn(name=ACPActivityName.TASK_CANCEL) @@ -89,4 +102,5 @@ async def task_cancel(self, params: TaskCancelParams) -> Task: agent_name=params.agent_name, trace_id=params.trace_id, parent_span_id=params.parent_span_id, + request=params.request, ) diff --git a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py index 8ae90a45..c26fe7bd 100644 --- a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py +++ b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py @@ -26,6 +26,10 @@ from agentex.lib.utils.logging import make_logger from agentex.lib.utils.model_utils import BaseModel from agentex.lib.utils.registration import register_agent +from agentex.lib.sdk.fastacp.base.constants import ( + FASTACP_HEADER_SKIP_EXACT, + FASTACP_HEADER_SKIP_PREFIXES, +) logger = make_logger(__name__) @@ -128,9 +132,23 @@ async def _handle_jsonrpc(self, request: Request): ), ) - # Parse params into appropriate model based on method + # Extract application headers, excluding sensitive/transport headers per FASTACP_* rules + # Forward filtered headers via params.request.headers to agent handlers + custom_headers = { + key: value + for key, value in request.headers.items() + if key.lower() not in FASTACP_HEADER_SKIP_EXACT + and not any(key.lower().startswith(p) for p in FASTACP_HEADER_SKIP_PREFIXES) + } + + # Parse params into appropriate model based on method and include headers params_model = PARAMS_MODEL_BY_METHOD[method] - params = params_model.model_validate(rpc_request.params) + params_data = dict(rpc_request.params) if rpc_request.params else {} + + # Add custom headers to the request structure if any headers were provided + if custom_headers: + params_data["request"] = {"headers": custom_headers} + params = params_model.model_validate(params_data) if method in RPC_SYNC_METHODS: handler = self._handlers[method] diff --git a/src/agentex/lib/sdk/fastacp/base/constants.py b/src/agentex/lib/sdk/fastacp/base/constants.py new file mode 100644 index 00000000..ed83ffd3 --- /dev/null +++ b/src/agentex/lib/sdk/fastacp/base/constants.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +# Header filtering rules for FastACP server + +# Prefixes to skip (case-insensitive beginswith checks) +FASTACP_HEADER_SKIP_PREFIXES: tuple[str, ...] = ( + "content-", + "host", + "user-agent", + "x-forwarded-", + "sec-", +) + +# Exact header names to skip (case-insensitive matching done by lowercasing keys) +FASTACP_HEADER_SKIP_EXACT: set[str] = { + "x-agent-api-key", + "connection", + "accept-encoding", + "cookie", + "content-length", + "transfer-encoding", +} + + diff --git a/src/agentex/lib/types/acp.py b/src/agentex/lib/types/acp.py index 4ec008c8..d93e2894 100644 --- a/src/agentex/lib/types/acp.py +++ b/src/agentex/lib/types/acp.py @@ -25,6 +25,7 @@ class CreateTaskParams(BaseModel): agent: The agent that the task was sent to. task: The task to be created. params: The parameters for the task as inputted by the user. + request: Additional request context including headers forwarded to this agent. """ agent: Agent = Field(..., description="The agent that the task was sent to") @@ -33,6 +34,10 @@ class CreateTaskParams(BaseModel): None, description="The parameters for the task as inputted by the user", ) + request: dict[str, Any] | None = Field( + default=None, + description="Additional request context including headers forwarded to this agent", + ) class SendMessageParams(BaseModel): @@ -43,6 +48,7 @@ class SendMessageParams(BaseModel): task: The task that the message was sent to. content: The message that was sent to the agent. stream: Whether to stream the message back to the agentex server from the agent. + request: Additional request context including headers forwarded to this agent. """ agent: Agent = Field(..., description="The agent that the message was sent to") @@ -54,6 +60,10 @@ class SendMessageParams(BaseModel): False, description="Whether to stream the message back to the agentex server from the agent", ) + request: dict[str, Any] | None = Field( + default=None, + description="Additional request context including headers forwarded to this agent", + ) class SendEventParams(BaseModel): @@ -63,11 +73,16 @@ class SendEventParams(BaseModel): agent: The agent that the event was sent to. task: The task that the message was sent to. event: The event that was sent to the agent. + request: Additional request context including headers forwarded to this agent. """ agent: Agent = Field(..., description="The agent that the event was sent to") task: Task = Field(..., description="The task that the message was sent to") event: Event = Field(..., description="The event that was sent to the agent") + request: dict[str, Any] | None = Field( + default=None, + description="Additional request context including headers forwarded to this agent", + ) class CancelTaskParams(BaseModel): @@ -76,10 +91,15 @@ class CancelTaskParams(BaseModel): Attributes: agent: The agent that the task was sent to. task: The task that was cancelled. + request: Additional request context including headers forwarded to this agent. """ agent: Agent = Field(..., description="The agent that the task was sent to") task: Task = Field(..., description="The task that was cancelled") + request: dict[str, Any] | None = Field( + default=None, + description="Additional request context including headers forwarded to this agent", + ) RPC_SYNC_METHODS = [ diff --git a/tests/test_header_forwarding.py b/tests/test_header_forwarding.py new file mode 100644 index 00000000..6e5b242f --- /dev/null +++ b/tests/test_header_forwarding.py @@ -0,0 +1,313 @@ +# ruff: noqa: I001 +from typing import Any +import sys +import types + +import pytest +from fastapi.testclient import TestClient + +"""Header forwarding tests consolidated. + +We stub tracing modules to avoid circular imports when importing ACPService. +""" + +# Stub tracing modules before importing ACPService +tracer_stub = types.ModuleType("agentex.lib.core.tracing.tracer") + +class _StubSpan: + async def __aenter__(self): + return self + async def __aexit__(self, exc_type, exc, tb): + return False + +class _StubTrace: + def span(self, **kwargs: Any) -> _StubSpan: # type: ignore[name-defined] + return _StubSpan() + +class _StubAsyncTracer: + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + def trace(self, trace_id: str | None = None) -> _StubTrace: # type: ignore[name-defined] + return _StubTrace() + +class _StubTracer(_StubAsyncTracer): + pass +tracer_stub.AsyncTracer = _StubAsyncTracer +tracer_stub.Tracer = _StubTracer +sys.modules["agentex.lib.core.tracing.tracer"] = tracer_stub + +tracing_pkg_stub = types.ModuleType("agentex.lib.core.tracing") +tracing_pkg_stub.AsyncTracer = _StubAsyncTracer +tracing_pkg_stub.Tracer = _StubTracer +sys.modules["agentex.lib.core.tracing"] = tracing_pkg_stub + +from agentex.lib.core.services.adk.acp.acp import ACPService +from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer +from agentex.lib.types.acp import RPCMethod, SendMessageParams +from agentex.types.task_message_content import TextContent + + +class DummySpan: + def __init__(self, **_kwargs: Any) -> None: + self.output = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class DummyTrace: + def span(self, **kwargs: Any) -> DummySpan: + return DummySpan(**kwargs) + + +class DummyTracer: + def trace(self, trace_id: str | None = None) -> DummyTrace: + return DummyTrace() + + +class DummyAgents: + async def rpc_by_name(self, *args: Any, **kwargs: Any) -> Any: + # Support both positional and keyword agent name, and both params/_params + method = kwargs.get("method") + extra_headers = kwargs.get("extra_headers") + # Ensure headers are forwarded as-is + assert extra_headers == {"x-user": "a", "authorization": "b"} + # Minimal response object with .result + if method == "task/create": + return type("R", (), {"result": {"id": "t1"}})() + if method == "message/send": + # include required task_id for TaskMessage model + return type("R", (), {"result": {"id": "m1", "task_id": "t1", "content": {"type": "text", "author": "user", "content": "ok"}}})() + if method == "event/send": + # include required fields for Event model + return type("R", (), {"result": {"id": "e1", "agent_id": "a1", "task_id": "t1", "sequence_id": 1}})() + if method == "task/cancel": + return type("R", (), {"result": {"id": "t1"}})() + raise AssertionError("Unexpected method") + + +class DummyClient: + def __init__(self) -> None: + self.agents = DummyAgents() + + +@pytest.mark.asyncio +async def test_header_forwarding() -> None: + client = DummyClient() + svc = ACPService(agentex_client=client, tracer=DummyTracer()) # type: ignore[arg-type] + + # Create task + task = await svc.task_create(agent_name="x", request={"headers": {"x-user": "a", "authorization": "b"}}) + assert task.id == "t1" + + # Send message + msgs = await svc.message_send( + agent_name="x", + task_id="t1", + content=TextContent(author="user", content="hi"), + request={"headers": {"x-user": "a", "authorization": "b"}}, + ) + assert len(msgs) == 1 + + # Send event + evt = await svc.event_send( + agent_name="x", + task_id="t1", + content=TextContent(author="user", content="hi"), + request={"headers": {"x-user": "a", "authorization": "b"}}, + ) + assert evt.id == "e1" + + # Cancel + task2 = await svc.task_cancel(agent_name="x", task_id="t1", request={"headers": {"x-user": "a", "authorization": "b"}}) + assert task2.id == "t1" + + +class TestServer(BaseACPServer): + __test__ = False + def _setup_handlers(self): + @self.on_message_send + async def handler(params: SendMessageParams): + headers = (params.request or {}).get("headers", {}) + assert "x-agent-api-key" not in headers + assert headers.get("x-user") == "a" + return TextContent(author="assistant", content="ok") + + +def test_excludes_agent_api_key_header(): + app = TestServer.create() + client = TestClient(app) + req = { + "jsonrpc": "2.0", + "method": RPCMethod.MESSAGE_SEND.value, + "params": { + "agent": {"id": "a1", "name": "n1", "description": "d", "acp_type": "sync"}, + "task": {"id": "t1"}, + "content": {"type": "text", "author": "user", "content": "hi"}, + "stream": False, + }, + "id": 1, + } + r = client.post("/api", json=req, headers={"x-user": "a", "x-agent-api-key": "secret"}) + assert r.status_code == 200 + + +def filter_headers_standalone( + headers: dict[str, str] | None, + allowlist: list[str] | None +) -> dict[str, str]: + """Standalone header filtering function matching the production implementation.""" + if not headers: + return {} + + # Pass-through behavior: if no allowlist, forward all headers + if allowlist is None: + return headers + + # Apply filtering based on allowlist + if not allowlist: + return {} + + import fnmatch + filtered = {} + for header_name, header_value in headers.items(): + # Check against allowlist patterns (case-insensitive) + header_allowed = False + for pattern in allowlist: + if fnmatch.fnmatch(header_name.lower(), pattern.lower()): + header_allowed = True + break + + if header_allowed: + filtered[header_name] = header_value + + return filtered + + +def test_filter_headers_no_headers() -> None: + allowlist = ["x-user-email"] + result = filter_headers_standalone(None, allowlist) + assert result == {} + + result = filter_headers_standalone({}, allowlist) + assert result == {} + + +def test_filter_headers_pass_through_by_default() -> None: + headers = { + "x-user-email": "test@example.com", + "x-admin-token": "secret", + "authorization": "Bearer token", + "x-custom-header": "value" + } + result = filter_headers_standalone(headers, None) + assert result == headers + + +def test_filter_headers_empty_allowlist() -> None: + allowlist: list[str] = [] + headers = {"x-user-email": "test@example.com", "x-admin-token": "secret"} + result = filter_headers_standalone(headers, allowlist) + assert result == {} + + +def test_filter_headers_allowed_headers() -> None: + allowlist = ["x-user-email", "x-tenant-id"] + headers = { + "x-user-email": "test@example.com", + "x-tenant-id": "tenant123", + "x-admin-token": "secret", + "content-type": "application/json" + } + result = filter_headers_standalone(headers, allowlist) + expected = { + "x-user-email": "test@example.com", + "x-tenant-id": "tenant123" + } + assert result == expected + + +def test_filter_headers_case_insensitive_patterns() -> None: + allowlist = ["X-User-Email", "x-tenant-*"] + headers = { + "x-user-email": "test@example.com", + "X-TENANT-ID": "tenant123", + "x-tenant-name": "acme", + "x-admin-token": "secret" + } + result = filter_headers_standalone(headers, allowlist) + expected = { + "x-user-email": "test@example.com", + "X-TENANT-ID": "tenant123", + "x-tenant-name": "acme" + } + assert result == expected + + +def test_filter_headers_wildcard_patterns() -> None: + allowlist = ["x-user-*", "authorization"] + headers = { + "x-user-id": "123", + "x-user-email": "test@example.com", + "x-user-role": "admin", + "authorization": "Bearer token", + "x-system-info": "blocked", + "content-type": "application/json" + } + result = filter_headers_standalone(headers, allowlist) + expected = { + "x-user-id": "123", + "x-user-email": "test@example.com", + "x-user-role": "admin", + "authorization": "Bearer token" + } + assert result == expected + + +def test_filter_headers_complex_patterns() -> None: + allowlist = ["x-tenant-*", "x-user-[abc]*", "auth*"] + headers = { + "x-tenant-id": "tenant1", + "x-tenant-name": "acme", + "x-user-admin": "true", + "x-user-beta": "false", + "x-user-delta": "test", + "authorization": "Bearer x", + "authenticate": "digest", + "content-type": "json", + } + result = filter_headers_standalone(headers, allowlist) + expected = { + "x-tenant-id": "tenant1", + "x-tenant-name": "acme", + "x-user-admin": "true", + "x-user-beta": "false", + "authorization": "Bearer x", + "authenticate": "digest" + } + assert result == expected + + +def test_filter_headers_all_types() -> None: + allowlist = ["authorization", "accept-language", "custom-*"] + headers = { + "authorization": "Bearer token", + "accept-language": "en-US", + "custom-header": "value", + "custom-auth": "token", + "content-type": "application/json", + "x-blocked": "value" + } + result = filter_headers_standalone(headers, allowlist) + expected = { + "authorization": "Bearer token", + "accept-language": "en-US", + "custom-header": "value", + "custom-auth": "token" + } + assert result == expected + +