/
utils.py
50 lines (43 loc) · 1.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from inspect import signature
from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast
from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model
def create_schema_from_function(
    name: str,
    func: Callable[..., Any],
    additional_fields: Optional[
        List[Union[Tuple[str, Type, Any], Tuple[str, Type]]]
    ] = None,
) -> Type[BaseModel]:
    """Build a model class named ``name`` from the signature of ``func``.

    Each parameter of ``func`` contributes one field. The parameter's
    annotation becomes the field type (``Any`` when the parameter is not
    annotated). A parameter without a default gets a bare ``FieldInfo()``,
    a parameter whose default is already a ``FieldInfo`` keeps that object
    as-is, and any other default is wrapped in ``FieldInfo(default=...)``.

    Args:
        name: Name handed to ``create_model`` for the generated class.
        func: Callable whose signature is inspected.
        additional_fields: Optional extra fields, processed after the
            parameters of ``func``. Each entry is either ``(name, type)``
            or ``(name, type, default)``. An entry whose name equals a
            parameter name overwrites the field built for that parameter.

    Returns:
        Whatever ``create_model(name, **fields)`` returns.

    Raises:
        ValueError: If an ``additional_fields`` entry has a length other
            than 2 or 3.
    """
    schema_fields = {}

    # One field per parameter, in signature order.
    for arg_name, param in signature(func).parameters.items():
        annotation = param.annotation
        if annotation is param.empty:
            annotation = Any

        default = param.default
        if default is param.empty:
            # Required field
            spec = FieldInfo()
        elif isinstance(default, FieldInfo):
            # The default is already a field description; reuse it directly.
            spec = default
        else:
            spec = FieldInfo(default=default)

        schema_fields[arg_name] = (annotation, spec)

    # Caller-supplied extras; the tuple length decides whether a default exists.
    for extra in additional_fields or []:
        arity = len(extra)
        if arity == 3:
            extra_name, extra_type, extra_default = cast(
                Tuple[str, Type, Any], extra
            )
            schema_fields[extra_name] = (
                extra_type,
                FieldInfo(default=extra_default),
            )
        elif arity == 2:
            # Required field has no default value
            extra_name, extra_type = cast(Tuple[str, Type], extra)
            schema_fields[extra_name] = (extra_type, FieldInfo())
        else:
            raise ValueError(
                f"Invalid additional field info: {extra}. "
                "Must be a tuple of length 2 or 3."
            )

    return create_model(name, **schema_fields)  # type: ignore