Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ try:
except UnexpectedModelBehavior as e:
print(e) # (1)!
"""
Safety settings triggered, body:
Content filter 'SAFETY' triggered, body:
<safety settings details>
"""
```
Expand Down
40 changes: 26 additions & 14 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
GoogleFinishReason.IMAGE_PROHIBITED_CONTENT: 'content_filter',
GoogleFinishReason.NO_IMAGE: 'error',
}


Expand Down Expand Up @@ -453,23 +455,28 @@ async def _build_content_and_config(
def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
if not response.candidates:
raise UnexpectedModelBehavior('Expected at least one candidate in Gemini response') # pragma: no cover

candidate = response.candidates[0]
if candidate.content is None or candidate.content.parts is None:
if candidate.finish_reason == 'SAFETY':
raise UnexpectedModelBehavior('Safety settings triggered', str(response))
else:
raise UnexpectedModelBehavior(
'Content field missing from Gemini response', str(response)
) # pragma: no cover
parts = candidate.content.parts or []

vendor_id = response.response_id
vendor_details: dict[str, Any] | None = None
finish_reason: FinishReason | None = None
if raw_finish_reason := candidate.finish_reason: # pragma: no branch
raw_finish_reason = candidate.finish_reason
if raw_finish_reason: # pragma: no branch
vendor_details = {'finish_reason': raw_finish_reason.value}
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

if candidate.content is None or candidate.content.parts is None:
if finish_reason == 'content_filter' and raw_finish_reason:
raise UnexpectedModelBehavior(
f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
)
else:
raise UnexpectedModelBehavior(
'Content field missing from Gemini response', response.model_dump_json()
) # pragma: no cover
parts = candidate.content.parts or []

usage = _metadata_as_usage(response)
return _process_response_from_parts(
parts,
Expand Down Expand Up @@ -623,7 +630,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if chunk.response_id: # pragma: no branch
self.provider_response_id = chunk.response_id

if raw_finish_reason := candidate.finish_reason:
raw_finish_reason = candidate.finish_reason
if raw_finish_reason:
self.provider_details = {'finish_reason': raw_finish_reason.value}
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

Expand All @@ -641,13 +649,17 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# )

if candidate.content is None or candidate.content.parts is None:
if candidate.finish_reason == 'STOP': # pragma: no cover
if self.finish_reason == 'stop': # pragma: no cover
# Normal completion - skip this chunk
continue
elif candidate.finish_reason == 'SAFETY': # pragma: no cover
raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
elif self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
raise UnexpectedModelBehavior(
f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
)
else: # pragma: no cover
raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
raise UnexpectedModelBehavior(
'Content field missing from streaming Gemini response', chunk.model_dump_json()
)

parts = candidate.content.parts
if not parts:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p
)
agent = Agent(m, instructions='You hate the world!', model_settings=settings)

with pytest.raises(UnexpectedModelBehavior, match='Safety settings triggered'):
with pytest.raises(UnexpectedModelBehavior, match="Content filter 'SAFETY' triggered"):
await agent.run('Tell me a joke about a Brazilians.')


Expand Down
2 changes: 1 addition & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ async def model_logic( # noqa: C901
]
)
elif m.content.startswith('Write a list of 5 very rude things that I might say'):
raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
raise UnexpectedModelBehavior("Content filter 'SAFETY' triggered", body='<safety settings details>')
elif m.content.startswith('<user>\n <name>John Doe</name>'):
return ModelResponse(
parts=[ToolCallPart(tool_name='final_result_EmailOk', args={}, tool_call_id='pyd_ai_tool_call_id')]
Expand Down