-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
pydantic.py
66 lines (51 loc) · 2.03 KB
/
pydantic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Pydantic output parser."""
import json
from typing import Any, List, Optional, Type
from llama_index.core.output_parsers.base import ChainableOutputParser
from llama_index.core.output_parsers.utils import extract_json_str
from llama_index.core.types import Model
# Default template for the structured-output instructions. It is filled via
# ``str.format`` with a single ``schema`` field holding the JSON-serialized
# schema of the output class (see ``PydanticOutputParser.get_format_string``).
PYDANTIC_FORMAT_TMPL = """
Here's a JSON schema to follow:
{schema}
Output a valid JSON object but do not repeat the schema.
"""
class PydanticOutputParser(ChainableOutputParser):
    """Pydantic Output Parser.

    Builds prompt instructions from a Pydantic model's JSON schema and parses
    LLM output text back into an instance of that model.

    Args:
        output_cls (BaseModel): Pydantic output class.
        excluded_schema_keys_from_format (Optional[List]): Top-level keys to
            drop from the model's JSON schema before it is embedded in the
            format instructions (e.g. ``["title"]``). Keys that are not present
            in the schema are ignored.
        pydantic_format_tmpl (str): Template for the format instructions. It
            is filled via ``str.format`` with a single ``schema`` field.
    """

    def __init__(
        self,
        output_cls: Type[Model],
        excluded_schema_keys_from_format: Optional[List] = None,
        pydantic_format_tmpl: str = PYDANTIC_FORMAT_TMPL,
    ) -> None:
        """Init params."""
        self._output_cls = output_cls
        self._excluded_schema_keys_from_format = excluded_schema_keys_from_format or []
        self._pydantic_format_tmpl = pydantic_format_tmpl

    @property
    def output_cls(self) -> Type[Model]:
        """Pydantic model class that parsed output is validated against."""
        return self._output_cls

    @property
    def format_string(self) -> str:
        """Format instructions with JSON braces escaped as ``{{`` / ``}}``."""
        return self.get_format_string(escape_json=True)

    def get_format_string(self, escape_json: bool = True) -> str:
        """Build the format instructions for the output class.

        Args:
            escape_json: If True, double every ``{`` and ``}`` in the result so
                it can be embedded in a template that is later passed through
                ``str.format`` without the JSON schema being read as fields.

        Returns:
            ``pydantic_format_tmpl`` filled with the JSON-serialized schema,
            minus any excluded top-level keys.
        """
        # Build a filtered copy instead of deleting keys in place. Pydantic v1's
        # ``schema()`` returns the dict cached on the model class, so an in-place
        # ``del`` corrupted the shared cached schema and made any second call
        # raise KeyError for a key that had already been removed.
        excluded = self._excluded_schema_keys_from_format
        schema_dict = {
            key: value
            for key, value in self._output_cls.schema().items()
            if key not in excluded
        }

        schema_str = json.dumps(schema_dict)
        output_str = self._pydantic_format_tmpl.format(schema=schema_str)

        if escape_json:
            # Replacing "{" first is safe: the "{{" it produces has no "}".
            return output_str.replace("{", "{{").replace("}", "}}")
        return output_str

    def parse(self, text: str) -> Any:
        """Extract the JSON object from ``text`` and validate it into ``output_cls``.

        Errors raised by ``extract_json_str`` or by the model's ``parse_raw``
        validation are not caught here and propagate to the caller.
        """
        json_str = extract_json_str(text)
        return self._output_cls.parse_raw(json_str)

    def format(self, query: str) -> str:
        """Append the brace-escaped format instructions to ``query``.

        Escaping is applied because the returned text is typically used as a
        prompt template that is formatted again later.
        """
        return query + "\n\n" + self.get_format_string(escape_json=True)