7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -562,6 +562,13 @@ class ModelResponse:
kind: Literal['response'] = 'response'
"""Message type identifier, this is available on all parts as a discriminator."""

vendor_details: dict[str, Any] | None = field(default=None, repr=False)
"""Additional vendor-specific details in a serializable format.

This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
"""

def otel_events(self) -> list[Event]:
"""Return OpenTelemetry events for the response."""
result: list[Event] = []
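A minimal sketch of how the new field might be consumed after a run, assuming an OpenAI-backed agent (the model name and prompt here are illustrative, not part of the diff):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel, OpenAIModelSettings

# Hypothetical usage: request logprobs so the OpenAI model populates vendor_details.
agent = Agent(OpenAIModel('gpt-4o'))
result = agent.run_sync(
    'Say hello',
    model_settings=OpenAIModelSettings(openai_logprobs=True),
)
response = result.all_messages()[-1]  # the final ModelResponse
if response.vendor_details is not None:
    print(response.vendor_details['logprobs'][0]['token'])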
35 changes: 34 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -104,6 +104,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
result in faster responses and fewer tokens used on reasoning in a response.
"""

openai_logprobs: bool
"""Include log probabilities in the response."""

openai_top_logprobs: int
"""Include log probabilities of the top n tokens in the response."""

openai_user: str
"""A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

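A short sketch of the two new settings used together (values illustrative; the OpenAI API only honors top_logprobs when logprobs is enabled):

from pydantic_ai.models.openai import OpenAIModelSettings

settings = OpenAIModelSettings(
    openai_logprobs=True,
    openai_top_logprobs=3,  # also return the 3 most likely alternatives per position
)
# then: await agent.run('...', model_settings=settings)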
@@ -287,6 +293,8 @@ async def _completions_create(
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
user=model_settings.get('openai_user', NOT_GIVEN),
extra_headers=extra_headers,
extra_body=model_settings.get('extra_body'),
@@ -301,12 +309,37 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
choice = response.choices[0]
items: list[ModelResponsePart] = []
vendor_details: dict[str, Any] | None = None

# Add logprobs to vendor_details if available
if choice.logprobs is not None and choice.logprobs.content:
# Convert logprobs to a serializable format
vendor_details = {
'logprobs': [
{
'token': lp.token,
'bytes': lp.bytes,
'logprob': lp.logprob,
'top_logprobs': [
{'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
],
}
for lp in choice.logprobs.content
],
}

if choice.message.content is not None:
items.append(TextPart(choice.message.content))
if choice.message.tool_calls is not None:
for c in choice.message.tool_calls:
items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            vendor_details=vendor_details,
+        )

async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
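Because the stored values are natural-log probabilities, recovering plain probabilities from vendor_details takes one call to math.exp; a sketch, assuming response is a ModelResponse produced by _process_response above:

import math

details = response.vendor_details or {}
for entry in details.get('logprobs', []):
    # exp(logprob) recovers the token probability, e.g. exp(-0.6931) ~= 0.5
    print(f"{entry['token']!r}: p={math.exp(entry['logprob']):.3f}")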
56 changes: 52 additions & 4 deletions tests/models/test_openai.py
@@ -40,7 +40,7 @@
with try_import() as imports_successful:
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI
from openai.types import chat
-    from openai.types.chat.chat_completion import Choice
+    from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_chunk import (
Choice as ChunkChoice,
ChoiceDelta,
@@ -49,6 +49,7 @@
)
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
from openai.types.completion_usage import CompletionUsage, PromptTokensDetails

from pydantic_ai.models.openai import (
@@ -129,10 +130,15 @@ def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str
raise RuntimeError('Not a MockOpenAI instance')


-def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
+def completion_message(
+    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
+) -> chat.ChatCompletion:
choices = [Choice(finish_reason='stop', index=0, message=message)]
if logprobs:
choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
return chat.ChatCompletion(
id='123',
-        choices=[Choice(finish_reason='stop', index=0, message=message)],
+        choices=choices,
created=1704067200, # 2024-01-01
model='gpt-4o-123',
object='chat.completion',
@@ -141,7 +147,9 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage


async def test_request_simple_success(allow_model_requests: None):
-    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
+    c = completion_message(
+        ChatCompletionMessage(content='world', role='assistant'),
+    )
mock_client = MockOpenAI.create_mock(c)
m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
agent = Agent(m)
@@ -1543,3 +1551,43 @@ async def get_temperature(city: str) -> float:
),
]
)


@pytest.mark.vcr()
async def test_openai_instructions_with_logprobs(allow_model_requests: None):
# Create a mock response with logprobs
c = completion_message(
ChatCompletionMessage(content='world', role='assistant'),
logprobs=ChoiceLogprobs(
content=[
ChatCompletionTokenLogprob(
token='world', logprob=-0.6931, top_logprobs=[], bytes=[119, 111, 114, 108, 100]
)
],
),
)

mock_client = MockOpenAI.create_mock(c)
m = OpenAIModel(
'gpt-4o',
provider=OpenAIProvider(openai_client=mock_client),
)
agent = Agent(
m,
instructions='You are a helpful assistant.',
)
result = await agent.run(
'What is the capital of Minas Gerais?',
model_settings=OpenAIModelSettings(openai_logprobs=True),
)
messages = result.all_messages()
response = cast(Any, messages[1])
assert response.vendor_details is not None
assert response.vendor_details['logprobs'] == [
{
'token': 'world',
'logprob': -0.6931,
'bytes': [119, 111, 114, 108, 100],
'top_logprobs': [],
}
]
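A hypothetical follow-up test (not part of this PR) could exercise top_logprobs the same way, building the mock with openai's TopLogprob type, which lives alongside ChatCompletionTokenLogprob:

from openai.types.chat.chat_completion_token_logprob import TopLogprob

c = completion_message(
    ChatCompletionMessage(content='world', role='assistant'),
    logprobs=ChoiceLogprobs(
        content=[
            ChatCompletionTokenLogprob(
                token='world',
                logprob=-0.6931,
                bytes=[119, 111, 114, 108, 100],
                top_logprobs=[
                    TopLogprob(token='world', logprob=-0.6931, bytes=[119, 111, 114, 108, 100]),
                    TopLogprob(token='World', logprob=-1.6094, bytes=[87, 111, 114, 108, 100]),
                ],
            )
        ],
    ),
)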
1 change: 1 addition & 0 deletions tests/test_agent.py
@@ -1757,6 +1757,7 @@ def test_binary_content_all_messages_json():
'model_name': 'test',
'timestamp': IsStr(),
'kind': 'response',
'vendor_details': None,
},
]
)