161 changes: 161 additions & 0 deletions tests/models/mock_openai.py
@@ -0,0 +1,161 @@
from __future__ import annotations as _annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Union, cast

from ..conftest import raise_if_exception, try_import
from .mock_async_stream import MockAsyncStream

with try_import() as imports_successful:
    from openai import NOT_GIVEN, AsyncOpenAI
    from openai.types import chat, responses
    from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
    from openai.types.chat.chat_completion_message import ChatCompletionMessage
    from openai.types.completion_usage import CompletionUsage
    from openai.types.responses.response import ResponseUsage
    from openai.types.responses.response_output_item import ResponseOutputItem

# note: we use Union here so that casting works with Python 3.9
MockChatCompletion = Union[chat.ChatCompletion, Exception]
MockChatCompletionChunk = Union[chat.ChatCompletionChunk, Exception]
MockResponse = Union[responses.Response, Exception]
MockResponseStreamEvent = Union[responses.ResponseStreamEvent, Exception]


@dataclass
class MockOpenAI:
    completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
    stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
    index: int = 0
    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def chat(self) -> Any:
        chat_completions = type('Completions', (), {'create': self.chat_completions_create})
        return type('Chat', (), {'completions': chat_completions})

    @classmethod
    def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncOpenAI:
        return cast(AsyncOpenAI, cls(completions=completions))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
    ) -> AsyncOpenAI:
        return cast(AsyncOpenAI, cls(stream=stream))

    async def chat_completions_create(  # pragma: lax no cover
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
        else:
            assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
            if isinstance(self.completions, Sequence):
                raise_if_exception(self.completions[self.index])
                response = cast(chat.ChatCompletion, self.completions[self.index])
            else:
                raise_if_exception(self.completions)
                response = cast(chat.ChatCompletion, self.completions)
        self.index += 1
        return response


def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    if isinstance(async_open_ai, MockOpenAI):
        return async_open_ai.chat_completion_kwargs
    else:  # pragma: no cover
        raise RuntimeError('Not a MockOpenAI instance')


def completion_message(
    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
) -> chat.ChatCompletion:
    choices = [Choice(finish_reason='stop', index=0, message=message)]
    if logprobs:
        choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
    return chat.ChatCompletion(
        id='123',
        choices=choices,
        created=1704067200,  # 2024-01-01
        model='gpt-4o-123',
        object='chat.completion',
        usage=usage,
    )


@dataclass
class MockOpenAIResponses:
    response: MockResponse | Sequence[MockResponse] | None = None
    stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]] | None = None
    index: int = 0
    response_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def responses(self) -> Any:
        return type('Responses', (), {'create': self.responses_create})

    @classmethod
    def create_mock(cls, responses: MockResponse | Sequence[MockResponse]) -> AsyncOpenAI:
        return cast(AsyncOpenAI, cls(response=responses))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]],
    ) -> AsyncOpenAI:
        return cast(AsyncOpenAI, cls(stream=stream))  # pragma: lax no cover

    async def responses_create(  # pragma: lax no cover
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> responses.Response | MockAsyncStream[MockResponseStreamEvent]:
        self.response_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockResponseStreamEvent], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockResponseStreamEvent], self.stream)))
        else:
            assert self.response is not None, 'you can only use `stream=False` if `response` is provided'
            if isinstance(self.response, Sequence):
                raise_if_exception(self.response[self.index])
                response = cast(responses.Response, self.response[self.index])
            else:
                raise_if_exception(self.response)
                response = cast(responses.Response, self.response)
        self.index += 1
        return response


def get_mock_responses_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    if isinstance(async_open_ai, MockOpenAIResponses):  # pragma: lax no cover
        return async_open_ai.response_kwargs
    else:  # pragma: no cover
        raise RuntimeError('Not a MockOpenAIResponses instance')


def response_message(
    output_items: Sequence[ResponseOutputItem], *, usage: ResponseUsage | None = None
) -> responses.Response:
    return responses.Response(
        id='123',
        model='gpt-4o-123',
        object='response',
        created_at=1704067200,  # 2024-01-01
        output=list(output_items),
        parallel_tool_calls=True,
        tool_choice='auto',
        tools=[],
        usage=usage,
    )
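
For context, a minimal sketch (not part of the diff) of how these helpers are typically wired into a test. The `OpenAIModel`/`OpenAIProvider` names, the `Agent` entry point, `result.output`, and the `allow_model_requests` fixture are assumptions drawn from pydantic-ai's public API and test conventions rather than from this diff:

from openai.types.chat.chat_completion_message import ChatCompletionMessage

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel  # assumed model class
from pydantic_ai.providers.openai import OpenAIProvider  # assumed provider class

from .mock_openai import MockOpenAI, completion_message, get_mock_chat_completion_kwargs


async def test_mocked_completion(allow_model_requests: None):  # fixture assumed from conftest
    # Canned ChatCompletion that the mock client will return.
    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)

    # The cast inside `create_mock` lets the mock stand in for a real AsyncOpenAI client.
    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(m)

    result = await agent.run('hello')
    assert result.output == 'world'

    # Every request's kwargs are recorded, so tests can assert on what was sent.
    assert get_mock_chat_completion_kwargs(mock_client)[0]['model'] == 'gpt-4o'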
81 changes: 5 additions & 76 deletions tests/models/test_openai.py
@@ -1,11 +1,9 @@
 from __future__ import annotations as _annotations

 import json
-from collections.abc import Sequence
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from datetime import datetime, timezone
 from enum import Enum
-from functools import cached_property
 from typing import Annotated, Any, Callable, Literal, Union, cast

 import httpx
@@ -47,13 +45,13 @@
 from pydantic_ai.tools import ToolDefinition
 from pydantic_ai.usage import RequestUsage

-from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, raise_if_exception, try_import
-from .mock_async_stream import MockAsyncStream
+from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, try_import
+from .mock_openai import MockOpenAI, completion_message, get_mock_chat_completion_kwargs

 with try_import() as imports_successful:
-    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI
+    from openai import APIStatusError, AsyncOpenAI
     from openai.types import chat
-    from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
+    from openai.types.chat.chat_completion import ChoiceLogprobs
     from openai.types.chat.chat_completion_chunk import (
         Choice as ChunkChoice,
         ChoiceDelta,
@@ -98,75 +96,6 @@ def test_init():
     assert m.model_name == 'gpt-4o'


-@dataclass
-class MockOpenAI:
-    completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
-    stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
-    index: int = 0
-    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
-
-    @cached_property
-    def chat(self) -> Any:
-        chat_completions = type('Completions', (), {'create': self.chat_completions_create})
-        return type('Chat', (), {'completions': chat_completions})
-
-    @classmethod
-    def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncOpenAI:
-        return cast(AsyncOpenAI, cls(completions=completions))
-
-    @classmethod
-    def create_mock_stream(
-        cls,
-        stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
-    ) -> AsyncOpenAI:
-        return cast(AsyncOpenAI, cls(stream=stream))
-
-    async def chat_completions_create(  # pragma: lax no cover
-        self, *_args: Any, stream: bool = False, **kwargs: Any
-    ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
-        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
-
-        if stream:
-            assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided'
-            if isinstance(self.stream[0], Sequence):
-                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
-            else:
-                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
-        else:
-            assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided'
-            if isinstance(self.completions, Sequence):
-                raise_if_exception(self.completions[self.index])
-                response = cast(chat.ChatCompletion, self.completions[self.index])
-            else:
-                raise_if_exception(self.completions)
-                response = cast(chat.ChatCompletion, self.completions)
-        self.index += 1
-        return response
-
-
-def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
-    if isinstance(async_open_ai, MockOpenAI):
-        return async_open_ai.chat_completion_kwargs
-    else:  # pragma: no cover
-        raise RuntimeError('Not a MockOpenAI instance')
-
-
-def completion_message(
-    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
-) -> chat.ChatCompletion:
-    choices = [Choice(finish_reason='stop', index=0, message=message)]
-    if logprobs:
-        choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
-    return chat.ChatCompletion(
-        id='123',
-        choices=choices,
-        created=1704067200,  # 2024-01-01
-        model='gpt-4o-123',
-        object='chat.completion',
-        usage=usage,
-    )
-
-
 async def test_request_simple_success(allow_model_requests: None):
     c = completion_message(
         ChatCompletionMessage(content='world', role='assistant'),