Skip to content

Commit

Permalink
produce a tool call result that works for all known models
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Feb 27, 2024
1 parent bd26ad7 commit 4337ade
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
20 changes: 18 additions & 2 deletions chatlab/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from typing import Optional
from typing import Literal, Optional, Required, TypedDict

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam

Expand Down Expand Up @@ -99,8 +99,23 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:
"name": name,
}

class ChatCompletionToolMessageParamWithName(TypedDict, total=False):
content: Required[str]
"""The contents of the tool message."""

def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam:
role: Required[Literal["tool"]]
"""The role of the messages author, in this case `tool`."""

tool_call_id: Required[str]
"""Tool call that this message is responding to."""

name: Optional[str]
"""The name of the tool."""




def tool_result(tool_call_id: str, content: str, name: str) -> ChatCompletionToolMessageParamWithName:
"""Create a tool result message.
Args:
Expand All @@ -112,6 +127,7 @@ def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessagePar
"""
return {
"role": "tool",
"name": name,
"content": content,
"tool_call_id": tool_call_id,
}
Expand Down
7 changes: 6 additions & 1 deletion chatlab/views/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..registry import FunctionRegistry, FunctionArgumentError, UnknownFunctionError

from ..messaging import assistant_function_call, function_result
from ..messaging import assistant_function_call, function_result, tool_result

class ToolCalled(AutoUpdate):
"""Once a tool has finished up, this is the view."""
Expand All @@ -27,6 +27,11 @@ def get_function_called_message(self):
return function_result(self.name, self.result)


def get_tool_called_message(self):
# NOTE: OpenAI has mismatched types where it doesn't include the `name`
# xref: https://github.com/openai/openai-python/issues/1078
return tool_result(tool_call_id=self.id, content=self.result, name=self.name)


class ToolArguments(AutoUpdate):
id: str
Expand Down

0 comments on commit 4337ade

Please sign in to comment.