-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathsets.py
160 lines (136 loc) · 5.1 KB
/
sets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""A module for running OpenAI functions"""
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Callable, TYPE_CHECKING, overload
from .functions import FunctionResult, OpenAIFunction
from .wrapper import FunctionWrapper, WrapperConfig
if TYPE_CHECKING:
from ..json_type import JsonType
from ..openai_types import FunctionCall
class FunctionSet(ABC):
    """A skill set: supplies a functions schema and knows how to run those functions.

    Concrete subclasses must implement both the ``functions_schema`` property
    and the ``run_function`` method; instances are also directly callable.
    """

    @property
    @abstractmethod
    def functions_schema(self) -> list[JsonType]:
        """The schema describing every function this set offers"""
        ...

    @abstractmethod
    def run_function(self, input_data: FunctionCall) -> FunctionResult:
        """Execute the function requested by a function call

        Args:
            input_data (FunctionCall): The function call to execute

        Raises:
            FunctionNotFoundError: If the requested function is not found
        """
        ...

    def __call__(self, input_data: FunctionCall) -> JsonType:
        """Execute a function call and hand back only its raw result

        Args:
            input_data (FunctionCall): The function call received from OpenAI

        Returns:
            JsonType: The ``result`` attribute of what ``run_function`` returned
        """
        outcome = self.run_function(input_data)
        return outcome.result
class MutableFunctionSet(FunctionSet):
    """A skill set that can be modified - functions can be added and removed

    Subclasses implement the storage primitives ``_add_function`` and
    ``_remove_function``; the public ``add_function`` and ``remove_function``
    methods normalize their arguments and delegate to those primitives.
    """

    @abstractmethod
    def _add_function(self, function: OpenAIFunction) -> None:
        """Store an already-built function in this set

        Args:
            function (OpenAIFunction): The function to store
        """
        ...

    @overload
    def add_function(self, function: OpenAIFunction) -> OpenAIFunction:
        ...

    @overload
    def add_function(
        self,
        function: Callable[..., Any],
        *,
        name: str | None = None,
        description: str | None = None,
        save_return: bool = True,
        serialize: bool = True,
        remove_call: bool = False,
        interpret_as_response: bool = False,
    ) -> Callable[..., Any]:
        ...

    @overload
    def add_function(
        self,
        *,
        name: str | None = None,
        description: str | None = None,
        save_return: bool = True,
        serialize: bool = True,
        remove_call: bool = False,
        interpret_as_response: bool = False,
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        ...

    def add_function(
        self,
        function: OpenAIFunction | Callable[..., Any] | None = None,
        *,
        name: str | None = None,
        description: str | None = None,
        save_return: bool = True,
        serialize: bool = True,
        remove_call: bool = False,
        interpret_as_response: bool = False,
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]] | Callable[..., Any]:
        """Add a function

        Can be called directly (``fs.add_function(func)``), used as a bare
        decorator (``@fs.add_function``) or used as a decorator factory
        (``@fs.add_function(name="...")``).

        Args:
            function (OpenAIFunction | Callable[..., Any] | None): The function.
                An ``OpenAIFunction`` is added as-is and the keyword options
                below are ignored. A plain callable is wrapped in a
                ``FunctionWrapper`` built from the keyword options. ``None``
                makes this return a decorator.
            name (str): The name of the function. Defaults to the function's name.
            description (str): The description of the function. Defaults to getting
                the short description from the function's docstring.
            save_return (bool): Whether to send the return value of this
                function to the AI. Defaults to True.
            serialize (bool): Whether to serialize the return value of this
                function. Defaults to True. Otherwise, the return value must be a
                string.
            remove_call (bool): Whether to remove the function call from the AI's
                chat history. Defaults to False.
            interpret_as_response (bool): Whether to interpret the return
                value of this function as a response of the agent. Defaults to False.

        Returns:
            OpenAIFunction: The same object, when an ``OpenAIFunction`` is given
            Callable[..., Any]: The original (unwrapped) callable, when a
                callable is given
            Callable[[Callable[..., Any]], Callable[..., Any]]: A decorator,
                when ``function`` is None

        Raises:
            TypeError: If ``function`` is not an ``OpenAIFunction``, a callable
                or None
        """
        if isinstance(function, OpenAIFunction):
            self._add_function(function)
            return function
        if callable(function):
            self._add_function(
                FunctionWrapper(
                    function,
                    # NOTE(review): the first WrapperConfig field is passed as
                    # None; its meaning is defined in .wrapper — confirm there.
                    WrapperConfig(
                        None, save_return, serialize, remove_call, interpret_as_response
                    ),
                    name=name,
                    description=description,
                )
            )
            # Return the caller's own function so decorated names stay usable.
            return function
        if function is not None:
            # Any other value used to fall through to the decorator branch and
            # silently return a partial; fail loudly on misuse instead.
            raise TypeError(
                "add_function expects an OpenAIFunction, a callable or None, "
                f"got {type(function).__name__}"
            )
        # Decorator-factory usage: re-enter add_function with the same options
        # once the decorated function is supplied.
        return partial(
            self.add_function,
            name=name,
            description=description,
            save_return=save_return,
            serialize=serialize,
            remove_call=remove_call,
            interpret_as_response=interpret_as_response,
        )

    @abstractmethod
    def _remove_function(self, name: str) -> None:
        """Remove the function registered under ``name`` from this set

        Args:
            name (str): The name of the function to remove
        """
        ...

    def remove_function(
        self, function: str | OpenAIFunction | Callable[..., Any]
    ) -> None:
        """Remove a function

        A string is used as the function name, an ``OpenAIFunction`` is removed
        by its ``name`` attribute, and any other callable is removed by its
        ``__name__``. A plain callable that was added with a custom ``name=``
        is therefore not found through the callable itself; remove it by
        passing that custom name instead.

        Args:
            function (str | OpenAIFunction | Callable[..., Any]): The function

        Raises:
            AttributeError: If a callable without ``__name__`` is given
        """
        if isinstance(function, str):
            self._remove_function(function)
            return
        if isinstance(function, OpenAIFunction):
            self._remove_function(function.name)
            return
        self._remove_function(function.__name__)