-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtogglable_set.py
86 lines (71 loc) · 2.87 KB
/
togglable_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""A function set disabled by default that exposes a function to enable it."""
from __future__ import annotations
from typing import TYPE_CHECKING
from ..exceptions import FunctionNotFoundError
from .basic_set import BasicFunctionSet
from .functions import FunctionResult
from .functions import OpenAIFunction
if TYPE_CHECKING:
from ..json_type import JsonType
from ..openai_types import FunctionCall
class TogglableSet(BasicFunctionSet):
    """Function set that stays hidden until the AI explicitly switches it on.

    While the set is disabled, the only function advertised to the model is a
    single enable function that takes no arguments. Calling it flips the set
    to enabled. From then on, the wrapped functions are advertised and
    dispatched through ``BasicFunctionSet``.

    Args:
        enable_function_name (str): Name under which the enable function is
            exposed to the AI.
        enable_function_description (str, optional): Description attached to
            the enable function's schema. It is left out of the schema when it
            is not given or is empty.
        functions (list[OpenAIFunction], optional): The functions that become
            available once the set is enabled.
    """

    def __init__(
        self,
        enable_function_name: str,
        enable_function_description: str | None = None,
        functions: list[OpenAIFunction] | None = None,
    ) -> None:
        super().__init__(functions)
        # Every instance starts switched off. Only a call to the enable
        # function (or an explicit enable()) turns it on.
        self.enabled = False
        self.enable_function_name = enable_function_name
        self.enable_function_description = enable_function_description

    def enable(self) -> None:
        """Switch the set on so its wrapped functions are exposed."""
        self.enabled = True

    @property
    def _enable_function_schema(self) -> dict[str, JsonType]:
        """Build the OpenAI-style schema for the enable function.

        Returns:
            dict[str, JsonType]: A fresh schema dict containing the configured
            name and an empty object-typed parameter list. A ``description``
            key is added last, and only when a truthy description was
            configured.
        """
        description = self.enable_function_description
        base: dict[str, JsonType] = {
            "name": self.enable_function_name,
            # The enable function accepts no arguments.
            "parameters": {
                "type": "object",
                "properties": {},
            },
        }
        if not description:
            return base
        return {**base, "description": description}

    @property
    def functions_schema(self) -> list[JsonType]:
        """Return the function schemas the AI should currently see.

        Returns:
            list[JsonType]: The parent's full schema list once the set is
            enabled. Otherwise, a one-element list holding only the enable
            function's schema.
        """
        return (
            super().functions_schema
            if self.enabled
            else [self._enable_function_schema]
        )

    def run_function(self, input_data: FunctionCall) -> FunctionResult:
        """Dispatch a function call, treating the enable function specially.

        Once the set is enabled, every call is delegated to
        ``BasicFunctionSet.run_function``. While it is disabled, only the
        enable function is accepted. Calling it switches the set on and
        returns ``FunctionResult(enable_function_name, None, True)``.
        NOTE(review): confirm what the positional ``None``/``True`` mean in
        ``FunctionResult``.

        Args:
            input_data (FunctionCall): The function call requested by the AI.

        Returns:
            FunctionResult: The outcome of the call.

        Raises:
            FunctionNotFoundError: If the set is disabled and the call names
                anything other than the enable function. Once the set is
                enabled, any errors come from the parent implementation.
        """
        if self.enabled:
            return super().run_function(input_data)
        requested = input_data["name"]
        if requested != self.enable_function_name:
            raise FunctionNotFoundError(requested)
        self.enable()
        return FunctionResult(self.enable_function_name, None, True)