forked from openai/openai-agents-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: agent_output.py
144 lines (117 loc) · 5.22 KB
/
agent_output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin
from . import _utils
from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
# Key under which a non-object output type is nested when `AgentOutputSchema` wraps it,
# so the model is asked for JSON shaped like {"response": <value>}.
_WRAPPER_DICT_KEY = "response"
@dataclass(init=False)
class AgentOutputSchema:
    """Schema wrapper for an agent's final output type.

    Builds the JSON schema that describes ``output_type`` and turns the JSON text produced by
    the LLM back into a validated value of that type.
    """

    # Declared output type; ``None`` or ``str`` means the agent answers in plain text.
    output_type: type[Any]
    # Pydantic adapter over the (possibly wrapped) output type, used to validate JSON.
    _type_adapter: TypeAdapter[Any]
    # True when the output type is nested under a single dictionary key. This is generally done
    # when the bare type could not be represented as a JSON Schema object.
    _is_wrapped: bool
    # JSON schema for the (possibly wrapped) output type.
    _output_schema: dict[str, Any]
    # Whether the schema is converted to strict mode. Strongly recommended: it increases the
    # likelihood that the model produces correct JSON.
    strict_json_schema: bool

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """Build the schema and validator for ``output_type``.

        Args:
            output_type: The type that the agent's output should be parsed into.
            strict_json_schema: Whether to make the JSON schema strict. Strongly recommended,
                since it increases the likelihood of correct JSON from the model.
        """
        self.output_type = output_type
        self.strict_json_schema = strict_json_schema

        plain_text = output_type is None or output_type is str
        target: Any = output_type
        if plain_text:
            # Plain text is never wrapped, and its schema is never made strict.
            self._is_wrapped = False
        else:
            # Anything that is not a model/dict would not be a JSON Schema object on its own,
            # so nest it under a single key of a TypedDict.
            self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)
            if self._is_wrapped:
                OutputType = TypedDict(
                    "OutputType",
                    {
                        _WRAPPER_DICT_KEY: output_type,  # type: ignore
                    },
                )
                target = OutputType

        self._type_adapter = TypeAdapter(target)
        self._output_schema = self._type_adapter.json_schema()

        if not plain_text and self.strict_json_schema:
            self._output_schema = ensure_strict_json_schema(self._output_schema)

    def is_plain_text(self) -> bool:
        """Return True if the output is free-form text rather than a JSON object."""
        declared = self.output_type
        return declared is None or declared is str

    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema for the output type.

        Raises:
            UserError: If the output is plain text, which has no JSON schema.
        """
        if not self.is_plain_text():
            return self._output_schema
        raise UserError("Output type is plain text, so no JSON schema is available")

    def validate_json(self, json_str: str, partial: bool = False) -> Any:
        """Validate ``json_str`` against the output type and return the parsed value.

        For a wrapped output type, the value stored under the wrapper key is returned instead
        of the wrapping dict.

        Raises:
            ModelBehaviorError: If the JSON is invalid.
        """
        validated = _utils.validate_json(json_str, self._type_adapter, partial)
        if not self._is_wrapped:
            return validated

        # A wrapped result must be a dict holding the wrapper key; any failure is recorded on
        # the current tracing span before raising.
        if not isinstance(validated, dict):
            _utils.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Expected a dict, got {type(validated)}"},
                )
            )
            raise ModelBehaviorError(
                f"Expected a dict, got {type(validated)} for JSON: {json_str}"
            )

        if _WRAPPER_DICT_KEY not in validated:
            _utils.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
                )
            )
            raise ModelBehaviorError(
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
            )

        return validated[_WRAPPER_DICT_KEY]

    def output_type_name(self) -> str:
        """Return a human-readable name for ``output_type``."""
        return _type_to_str(self.output_type)
def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
if not isinstance(t, type):
return False
# If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
origin = get_origin(t)
allowed_types = (BaseModel, dict)
# If it's a generic alias e.g. list[str], then we should check the origin type i.e. list
return issubclass(origin or t, allowed_types)
def _type_to_str(t: type[Any]) -> str:
origin = get_origin(t)
args = get_args(t)
if origin is None:
# It's a simple type like `str`, `int`, etc.
return t.__name__
elif args:
args_str = ", ".join(_type_to_str(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
else:
return str(t)