Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions app/features/agents/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import functools
import inspect
import json
import os
from collections.abc import Awaitable, Callable
from typing import Any
from typing import Any, cast

import httpx
import structlog
from pydantic_ai import ModelRetry
from pydantic_ai.models import Model
Expand Down Expand Up @@ -62,6 +64,71 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> ToolReturnT:
return wrapper


def _coerce_null_message_content(body: bytes) -> bytes | None:
"""Coerce ``messages[*].content: null`` -> ``""`` in a chat-request body.

Ollama's OpenAI-compatible ``/v1/chat/completions`` rejects any message
whose ``content`` is JSON ``null`` and which carries no ``tool_calls`` with
``400 invalid message content type: <nil>`` — stricter than the real OpenAI
API, which tolerates it. A weak local model can emit a degenerate empty
assistant turn (no text, no tool call); PydanticAI serialises it as
``content: null`` and then *replays* that message on its validation-retry,
so every retry 400s and the whole run dies with a ``FallbackExceptionGroup``.
Coercing ``null`` -> ``""`` keeps the message OpenAI-spec-valid and lets the
retry loop proceed.

Args:
body: The raw outgoing request body bytes.

Returns:
Re-serialised body bytes when a null ``content`` was rewritten, or
``None`` when nothing changed (the common case) so the caller can
forward the original request untouched.
"""
try:
parsed = json.loads(body)
except (ValueError, TypeError):
return None
if not isinstance(parsed, dict):
return None
payload = cast("dict[str, Any]", parsed)
messages = payload.get("messages")
if not isinstance(messages, list):
return None
message_list: list[Any] = messages
changed = False
for message in message_list:
if isinstance(message, dict) and "content" in message and message["content"] is None:
message["content"] = ""
changed = True
if not changed:
return None
return json.dumps(payload).encode("utf-8")


class _OllamaNullContentTransport(httpx.AsyncHTTPTransport):
"""httpx transport that null-content-sanitises outgoing Ollama requests.

See :func:`_coerce_null_message_content` for the Ollama-compat defect this
works around. Applied to the ``OllamaProvider``'s HTTP client so the fix
covers both the streaming and non-streaming agent paths.
"""

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
sanitized = _coerce_null_message_content(request.content)
if sanitized is not None:
headers = dict(request.headers)
headers.pop("content-length", None) # httpx recomputes from the new body
request = httpx.Request(
request.method,
request.url,
headers=headers,
content=sanitized,
extensions=request.extensions,
)
return await super().handle_async_request(request)


def build_agent_model(identifier: str) -> str | Model:
"""Build the PydanticAI ``model`` argument for an agent identifier.

Expand All @@ -85,7 +152,17 @@ def build_agent_model(identifier: str) -> str | Model:
model_name = identifier.split(":", 1)[1]
# CRITICAL: Ollama's OpenAI-compatible base ends in /v1.
base_url = settings.ollama_base_url.rstrip("/") + "/v1"
return OpenAIChatModel(model_name, provider=OllamaProvider(base_url=base_url))
# The null-content sanitiser lives on the HTTP client (see
# _OllamaNullContentTransport). A generous read timeout is required because
# local generation on an 8B model routinely exceeds httpx's 5s default.
http_client = httpx.AsyncClient(
transport=_OllamaNullContentTransport(),
timeout=httpx.Timeout(600.0, connect=10.0),
)
return OpenAIChatModel(
model_name,
provider=OllamaProvider(base_url=base_url, http_client=http_client),
)


def reset_agent_caches() -> None:
Expand Down
87 changes: 85 additions & 2 deletions app/features/agents/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,23 @@ async def chat(
error=str(e),
error_type=type(e).__name__,
)
session.last_activity = datetime.now(UTC)
misbehavior_now = datetime.now(UTC)
session.last_activity = misbehavior_now
# A gated tool may have fired (and recorded a valid approval request)
# before the model misbehaved — surface the Approve card rather than
# discarding it behind the generic error (#344).
salvaged = self._salvage_pending_action(session, deps, misbehavior_now)
await db.flush()
if salvaged is not None:
return ChatResponse(
session_id=session_id,
message=(
"I've prepared an action that needs your approval before "
"I can proceed. Please review the pending request."
),
pending_approval=True,
pending_action=salvaged,
)
return ChatResponse(
session_id=session_id,
message=(
Expand Down Expand Up @@ -650,6 +665,35 @@ async def stream_chat(
error=str(e),
error_type=type(e).__name__,
)
misbehavior_now = datetime.now(UTC)
session.last_activity = misbehavior_now
# A gated tool may have fired (and recorded a valid approval request)
# before the model misbehaved — surface the Approve card rather than
# discarding it behind the generic error (#344).
salvaged = self._salvage_pending_action(session, deps, misbehavior_now)
await db.flush()
if salvaged is not None:
yield StreamEvent(
event_type="approval_required",
data={
"action": salvaged,
"message": "Human approval required before proceeding.",
},
timestamp=misbehavior_now,
)
yield StreamEvent(
event_type="complete",
data={
"message": (
"I've prepared an action that needs your approval before I can proceed."
),
"tokens_used": 0,
"tool_calls_count": deps.tool_call_count,
"pending_approval": True,
},
timestamp=misbehavior_now,
)
return
yield StreamEvent(
event_type="error",
data={
Expand All @@ -660,7 +704,7 @@ async def stream_chat(
"error_type": "model_behavior_error",
"recoverable": True,
},
timestamp=datetime.now(UTC),
timestamp=misbehavior_now,
)
return

Expand Down Expand Up @@ -858,6 +902,45 @@ def _deserialize_messages(
)
return []

def _salvage_pending_action(
self,
session: AgentSession,
deps: AgentDeps,
now: datetime,
) -> PendingAction | None:
"""Persist a gated tool's approval request captured before a misbehaving run.

A gated tool sets ``deps.pending_action`` the moment it fires (#336), but
it does not halt the run. A weak model can ramble past the gate and
exhaust its retry budget, so ``agent.run()`` raises
``UnexpectedModelBehavior`` BEFORE returning and the normal post-run
approval-surfacing path never executes. The gate did fire and the
captured arguments are valid, so surface the approval card instead of
discarding it behind a generic error (issue #344).

Args:
session: The agent session to mutate.
deps: The agent deps that a gated tool may have written to.
now: Timestamp for created_at / expires_at.

Returns:
The formatted :class:`PendingAction` when a gated tool recorded a
request, else ``None`` (the genuine "invalid tool call" case).
"""
if not deps.pending_action:
return None
action_type = str(deps.pending_action.get("action_type", "unknown"))
return self._record_pending_action(
session,
action_type=action_type,
arguments=deps.pending_action.get("arguments") or {},
description=str(
deps.pending_action.get("description")
or f"Agent requested approval for {action_type}"
),
now=now,
)

def _record_pending_action(
self,
session: AgentSession,
Expand Down
135 changes: 135 additions & 0 deletions app/features/agents/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Unit tests for agent base helpers (Ollama-aware model factory)."""

import json
import re
from collections.abc import Iterator
from typing import Any, cast
from unittest.mock import AsyncMock

import httpx
import pytest
from pydantic_ai import ModelRetry
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
Expand All @@ -15,6 +17,8 @@
from app.core.config import get_settings
from app.features.agents.agents.base import (
TOOL_USAGE_INSTRUCTIONS,
_coerce_null_message_content,
_OllamaNullContentTransport,
build_agent_model,
build_agent_model_with_fallback,
get_agent_retries,
Expand Down Expand Up @@ -322,3 +326,134 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert captured["output_tools"] == []
assert isinstance(result.output, RAGAnswer)
assert result.output.confidence == "high"


class TestOllamaNullContentSanitizer:
"""The Ollama HTTP client must convert ``content: null`` -> ``""`` (#344).

Ollama's OpenAI-compatible ``/v1/chat/completions`` rejects any message
whose ``content`` is JSON ``null`` and carries no ``tool_calls`` with
``400 invalid message content type: <nil>``. PydanticAI emits that shape for
a degenerate empty assistant turn and then replays it on retry, so without
this coercion every retry 400s and the run dies with ``FallbackExceptionGroup``.
"""

def test_coerce_rewrites_null_content_to_empty_string(self) -> None:
body = json.dumps(
{
"model": "qwen3:8b",
"messages": [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": None},
],
}
).encode("utf-8")

out = _coerce_null_message_content(body)

assert out is not None
payload = json.loads(out)
assert payload["messages"][1]["content"] == ""
# Untouched fields survive the round-trip.
assert payload["messages"][0]["content"] == "hi"
assert payload["model"] == "qwen3:8b"

def test_coerce_rewrites_null_content_even_with_tool_calls(self) -> None:
body = json.dumps(
{
"messages": [
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "c1",
"type": "function",
"function": {"name": "x", "arguments": "{}"},
}
],
}
]
}
).encode("utf-8")

out = _coerce_null_message_content(body)

assert out is not None
payload = json.loads(out)
assert payload["messages"][0]["content"] == ""
assert payload["messages"][0]["tool_calls"][0]["id"] == "c1"

def test_coerce_is_noop_when_no_null_content(self) -> None:
body = json.dumps({"messages": [{"role": "user", "content": "hi"}]}).encode("utf-8")

assert _coerce_null_message_content(body) is None

def test_coerce_ignores_missing_content_key(self) -> None:
# A message with no ``content`` key at all must not be rewritten — only
# an explicit JSON null is the Ollama-rejected shape.
body = json.dumps({"messages": [{"role": "assistant", "tool_calls": []}]}).encode("utf-8")

assert _coerce_null_message_content(body) is None

def test_coerce_handles_non_json_body(self) -> None:
assert _coerce_null_message_content(b"not json at all") is None

def test_coerce_handles_non_dict_payload(self) -> None:
assert _coerce_null_message_content(b"[1, 2, 3]") is None

@pytest.mark.asyncio
async def test_transport_sanitizes_outgoing_request(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""The transport rewrites the body and fixes Content-Length before send."""
captured: dict[str, bytes] = {}

async def fake_send(
_self: httpx.AsyncHTTPTransport, request: httpx.Request
) -> httpx.Response:
captured["body"] = request.content
captured["content_length"] = request.headers["content-length"].encode()
return httpx.Response(200, json={"ok": True})

monkeypatch.setattr(httpx.AsyncHTTPTransport, "handle_async_request", fake_send)

transport = _OllamaNullContentTransport()
body = json.dumps({"messages": [{"role": "assistant", "content": None}]}).encode("utf-8")
request = httpx.Request("POST", "http://ollama/v1/chat/completions", content=body)

await transport.handle_async_request(request)

sent = json.loads(captured["body"])
assert sent["messages"][0]["content"] == ""
# Content-Length must match the rewritten body, not the original.
assert int(captured["content_length"]) == len(captured["body"])

@pytest.mark.asyncio
async def test_transport_passthrough_when_nothing_to_sanitize(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
captured: dict[str, bytes] = {}

async def fake_send(
_self: httpx.AsyncHTTPTransport, request: httpx.Request
) -> httpx.Response:
captured["body"] = request.content
return httpx.Response(200, json={"ok": True})

monkeypatch.setattr(httpx.AsyncHTTPTransport, "handle_async_request", fake_send)

transport = _OllamaNullContentTransport()
body = json.dumps({"messages": [{"role": "user", "content": "hi"}]}).encode("utf-8")
request = httpx.Request("POST", "http://ollama/v1/chat/completions", content=body)

await transport.handle_async_request(request)

# Forwarded unchanged.
assert json.loads(captured["body"])["messages"][0]["content"] == "hi"

def test_build_agent_model_returns_openai_chat_model_for_ollama(self) -> None:
# The Ollama branch must hand back a configured OpenAIChatModel (whose
# HTTP client carries the sanitizing transport), not the bare identifier.
model = build_agent_model("ollama:qwen3:8b")
assert isinstance(model, OpenAIChatModel)
Loading