Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 5 additions & 23 deletions src/writer/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from writerai.types.application_generate_content_params import Input
from writerai.types.chat import ChoiceMessage, ChoiceMessageGraphData, ChoiceMessageToolCall
from writerai.types.chat_chat_params import Message as WriterAIMessage
from writerai.types.chat_chat_params import MessageGraphData
from writerai.types.chat_chat_params import ToolFunctionTool as SDKFunctionTool
from writerai.types.chat_chat_params import ToolGraphTool as SDKGraphTool

Expand Down Expand Up @@ -94,22 +95,6 @@ class FunctionTool(Tool):
parameters: Dict[str, Dict[str, str]]


class PreparedAPIMessage(TypedDict, total=False):
role: Literal["user", "assistant", "system", "tool"]

content: Union[str, None]

name: Optional[str]

tool_call_id: Optional[str]

tool_calls: Optional[List[ChoiceMessageToolCall]]

graph_data: Optional[ChoiceMessageGraphData]

refusal: Optional[str]


def create_function_tool(
callable: Callable,
name: str,
Expand Down Expand Up @@ -1044,7 +1029,7 @@ def _clear_chunk_flag(chunk):
updated_last_message |= clear_chunk

@staticmethod
def _prepare_message(message: 'Conversation.Message') -> PreparedAPIMessage:
def _prepare_message(message: 'Conversation.Message') -> WriterAIMessage:
"""
Converts a message object stored in Conversation to a Writer AI SDK
`Message` model, suitable for calls to API.
Expand All @@ -1055,7 +1040,7 @@ def _prepare_message(message: 'Conversation.Message') -> PreparedAPIMessage:
"""
if not ("role" in message and "content" in message):
raise ValueError("Improper message format")
sdk_message = PreparedAPIMessage(
sdk_message = WriterAIMessage(
content=message["content"] or None,
role=message["role"]
)
Expand All @@ -1067,7 +1052,7 @@ def _prepare_message(message: 'Conversation.Message') -> PreparedAPIMessage:
sdk_message["tool_calls"] = cast(list, msg_tool_calls)
if msg_graph_data := message.get("graph_data"):
sdk_message["graph_data"] = cast(
ChoiceMessageGraphData,
MessageGraphData,
msg_graph_data
)
if msg_refusal := message.get("refusal"):
Expand Down Expand Up @@ -1350,13 +1335,10 @@ def _send_chat_request(
a Stream or a Chat object.
"""
client = WriterAIManager.acquire_client()
prepared_messages = cast(
Iterable[WriterAIMessage],
[
prepared_messages = [
self._prepare_message(message)
for message in self.messages
]
)
logging.debug(
"Attempting to request a message from LLM: " +
f"prepared messages – {prepared_messages}, " +
Expand Down