-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
test_base.py
63 lines (51 loc) · 2.6 KB
/
test_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Test object index."""
from pathlib import Path

from llama_index.legacy.indices.list.base import SummaryIndex
from llama_index.legacy.objects.base import ObjectIndex
from llama_index.legacy.objects.base_node_mapping import SimpleObjectNodeMapping
from llama_index.legacy.objects.tool_node_mapping import SimpleToolNodeMapping
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.tools.function_tool import FunctionTool
def test_object_index(mock_service_context: ServiceContext) -> None:
    """Test object index."""
    # Shared by the mapping and the index; never mutated below.
    objects = ["a", "b", "c"]
    node_mapping = SimpleObjectNodeMapping.from_objects(objects)
    index = ObjectIndex.from_objects(objects, node_mapping, index_cls=SummaryIndex)

    # With a SummaryIndex backing store the expected result is every object,
    # in insertion order, regardless of the query text.
    assert index.as_retriever().retrieve("test") == objects

    # An object inserted after construction must be retrievable too.
    index.insert_object("d")
    assert index.as_retriever().retrieve("test") == objects + ["d"]
def _assert_object_indexes_match(original: ObjectIndex, reloaded: ObjectIndex) -> None:
    """Assert ``reloaded`` has the same index id, index struct and object-node
    mapping as ``original``."""
    assert original._index.index_id == reloaded._index.index_id
    assert original._index.index_struct == reloaded._index.index_struct
    assert (
        original._object_node_mapping.obj_node_mapping
        == reloaded._object_node_mapping.obj_node_mapping
    )


def test_object_index_persist(
    mock_service_context: ServiceContext, tmp_path: Path
) -> None:
    """Test object index persist/load.

    The index is persisted into pytest's per-test ``tmp_path``, not the
    method's default location. The test therefore leaves no files in the
    working directory and cannot collide with other test runs.
    """
    persist_dir = str(tmp_path)
    object_mapping = SimpleObjectNodeMapping.from_objects(["a", "b", "c"])
    obj_index = ObjectIndex.from_objects(
        ["a", "b", "c"], object_mapping, index_cls=SummaryIndex
    )
    obj_index.persist(persist_dir=persist_dir)

    # version where the object-node mapping is loaded back from disk
    reloaded_obj_index = ObjectIndex.from_persist_dir(persist_dir=persist_dir)
    _assert_object_indexes_match(obj_index, reloaded_obj_index)

    # version where user passes in the object_node_mapping
    reloaded_obj_index = ObjectIndex.from_persist_dir(
        persist_dir=persist_dir, object_node_mapping=object_mapping
    )
    _assert_object_indexes_match(obj_index, reloaded_obj_index)
def test_object_index_with_tools(mock_service_context: ServiceContext) -> None:
    """Test object index with tools."""
    identity_tool = FunctionTool.from_defaults(fn=lambda x: x, name="test_tool")
    add_tool = FunctionTool.from_defaults(fn=lambda x, y: x + y, name="test_tool2")
    tools = [identity_tool, add_tool]

    tool_mapping = SimpleToolNodeMapping.from_objects(tools)
    tool_index = ObjectIndex.from_objects(tools, tool_mapping, index_cls=SummaryIndex)

    # Retrieval must hand back the original tool objects, in the same order.
    assert tool_index.as_retriever().retrieve("test") == tools