Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 59 additions & 74 deletions src/writerai/lib/_parsing/_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,25 @@

import pydantic

from writerai.types.shared_params.tool_param import ToolParam

from .._tools import PydanticFunctionTool
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_given
from ..._utils import is_dict, is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, is_dataclass_like_type
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ...types.parsed_chat import (
ChatCompletion,
ParsedChatCompletion,
ChatCompletionMessage,
ParsedFunctionToolCall,
ParsedChatCompletionChoice,
ParsedChatCompletionMessage,
)
from ...types.shared_params import FunctionDefinition
from ...types.chat_completion import ChatCompletion
from ...types.shared_params import ToolParam as ChatCompletionToolParam, FunctionDefinition
from ...types.chat_chat_params import ResponseFormat as ResponseFormatParam
from ...types.shared.tool_call import Function
from ...types.chat_completion_message import ChatCompletionMessage
from ...types.parsed_function_tool_call import (
ParsedFunction,
ParsedFunctionToolCall,
)
from ...types.parsed_function_tool_call import ParsedFunction

ResponseFormatT = TypeVar(
"ResponseFormatT",
Expand All @@ -38,7 +35,7 @@


def validate_input_tools(
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> None:
if not is_given(tools):
return
Expand All @@ -58,9 +55,8 @@ def validate_input_tools(

def parse_chat_completion(
*,
# response_format: type[ResponseFormatT] | chat_chat_params.ResponseFormat | NotGiven,
response_format: type[ResponseFormatT] | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
if is_given(input_tools):
Expand Down Expand Up @@ -113,7 +109,7 @@ def parse_chat_completion(
response_format=response_format,
message=message,
),
"tool_calls": tool_calls,
"tool_calls": tool_calls if tool_calls else None,
},
},
)
Expand All @@ -131,11 +127,13 @@ def parse_chat_completion(
)


def get_input_tool_by_name(*, input_tools: list[ToolParam], name: str) -> ToolParam | None:
def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
return next((t for t in input_tools if t.get("function", {}).get("name") == name), None)


def parse_function_tool_arguments(*, input_tools: list[ToolParam], function: Function | ParsedFunction) -> object:
def parse_function_tool_arguments(
*, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
) -> object:
assert function.name is not None
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
if not input_tool:
Expand All @@ -155,8 +153,7 @@ def parse_function_tool_arguments(*, input_tools: list[ToolParam], function: Fun

def maybe_parse_content(
*,
# response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
response_format: type[ResponseFormatT] | NotGiven,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
if has_rich_response_format(response_format) and message.content and not message.refusal:
Expand All @@ -166,8 +163,7 @@ def maybe_parse_content(


def solve_response_format_t(
# response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
response_format: type[ResponseFormatT] | NotGiven,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> type[ResponseFormatT]:
"""Return the runtime type for the given response format.

Expand All @@ -182,9 +178,8 @@ def solve_response_format_t(

def has_parseable_input(
*,
# response_format: type | ResponseFormatParam | NotGiven,
response_format: type | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
response_format: type | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> bool:
if has_rich_response_format(response_format):
return True
Expand All @@ -197,30 +192,27 @@ def has_parseable_input(


def has_rich_response_format(
# response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
response_format: type[ResponseFormatT] | NotGiven,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> TypeGuard[type[ResponseFormatT]]:
if not is_given(response_format):
return False

# if is_response_format_param(response_format):
# return False
if is_response_format_param(response_format):
return False

return True


# def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
# return is_dict(response_format)
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
return is_dict(response_format)


def is_parseable_tool(input_tool: ToolParam) -> bool:
def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
input_fn = cast(object, input_tool.get("function"))
if isinstance(input_fn, PydanticFunctionTool):
return True

# FIXME: `strict` currently missing in the schema definition
# return cast(FunctionDefinition, input_fn).get("strict") or False
return False
return cast(FunctionDefinition, input_fn).get("strict") or False # type: ignore


def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
Expand All @@ -236,43 +228,36 @@ def _parse_content(response_format: type[ResponseFormatT], content: str) -> Resp
raise TypeError(f"Unable to automatically parse response format type {response_format}")


# def type_to_response_format_param(
# response_format: type | completion_create_params.ResponseFormat | NotGiven,
# ) -> ResponseFormatParam | NotGiven:
# if not is_given(response_format):
# return NOT_GIVEN

# if is_response_format_param(response_format):
# return response_format

# # type checkers don't narrow the negation of a `TypeGuard` as it isn't
# # a safe default behaviour but we know that at this point the `response_format`
# # can only be a `type`
# response_format = cast(type, response_format)

# json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

# if is_basemodel_type(response_format):
# name = response_format.__name__
# json_schema_type = response_format
# elif is_dataclass_like_type(response_format):
# name = response_format.__name__
# json_schema_type = pydantic.TypeAdapter(response_format)
# else:
# raise TypeError(f"Unsupported response_format type - {response_format}")


# return {
# "type": "json_schema",
# "json_schema": {
# "schema": to_strict_json_schema(json_schema_type),
# "name": name,
# "strict": True,
# },
# }
def type_to_response_format_param(
response_format: type | NotGiven,
) -> NotGiven:
if is_given(response_format):
raise NotImplementedError("Support for response_format is not implemented yet")
return NOT_GIVEN
response_format: type | ResponseFormatParam | NotGiven,
) -> ResponseFormatParam | NotGiven:
if not is_given(response_format):
return NOT_GIVEN

if is_response_format_param(response_format):
return response_format

# type checkers don't narrow the negation of a `TypeGuard` as it isn't
# a safe default behaviour but we know that at this point the `response_format`
# can only be a `type`
response_format = cast(type, response_format)

json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

if is_basemodel_type(response_format):
name = response_format.__name__
json_schema_type = response_format
elif is_dataclass_like_type(response_format):
name = response_format.__name__
json_schema_type = pydantic.TypeAdapter(response_format)
else:
raise TypeError(f"Unsupported response_format type - {response_format}")

return {
"type": "json_schema",
"json_schema": {
"schema": to_strict_json_schema(json_schema_type),
"name": name,
"strict": True,
},
}
Loading