-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
test_refine.py
142 lines (121 loc) · 4.42 KB
/
test_refine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from collections import OrderedDict
from typing import Any, Dict, Optional, Type, cast
import pytest
from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.response_synthesizers import Refine
from llama_index.legacy.response_synthesizers.refine import (
StructuredRefineResponse,
)
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.types import BasePydanticProgram
class MockRefineProgram(BasePydanticProgram):
    """
    Fake structured-refine program driven by a lookup table.

    It never calls an LLM. Each call echoes its input text back as the answer
    and reports ``query_satisfied`` from the ``input_to_query_satisfied``
    mapping. Tests can therefore control exactly which inputs the ``Refine``
    answer filter treats as satisfying the query.
    """

    def __init__(self, input_to_query_satisfied: Dict[str, bool]):
        # Input text -> query_satisfied flag to report for that text.
        self._input_to_query_satisfied = input_to_query_satisfied

    @property
    def output_cls(self) -> Type[BaseModel]:
        return StructuredRefineResponse

    def _respond(
        self, context_str: Optional[str], context_msg: Optional[str]
    ) -> StructuredRefineResponse:
        """
        Build the canned response shared by ``__call__`` and ``acall``.

        Args:
            context_str: input text, if passed under this keyword.
            context_msg: input text, if passed under this keyword; used only
                when ``context_str`` is falsy.

        Returns:
            A ``StructuredRefineResponse`` whose answer is the input text and
            whose ``query_satisfied`` comes from the lookup table.

        Raises:
            KeyError: if the input text is not a key of the lookup table.
        """
        # Accept the text under whichever keyword the caller used.
        input_str = cast(str, context_str or context_msg)
        query_satisfied = self._input_to_query_satisfied[input_str]
        return StructuredRefineResponse(
            answer=input_str, query_satisfied=query_satisfied
        )

    def __call__(
        self,
        *args: Any,
        context_str: Optional[str] = None,
        context_msg: Optional[str] = None,
        **kwargs: Any
    ) -> StructuredRefineResponse:
        # Extra positional/keyword arguments are accepted and ignored.
        return self._respond(context_str, context_msg)

    async def acall(
        self,
        *args: Any,
        context_str: Optional[str] = None,
        context_msg: Optional[str] = None,
        **kwargs: Any
    ) -> StructuredRefineResponse:
        # Async twin of __call__; same lookup, no awaiting needed.
        return self._respond(context_str, context_msg)
@pytest.fixture()
def mock_refine_service_context(patch_llm_predictor: Any) -> ServiceContext:
    """ServiceContext wrapping the patched LLM predictor, with no callback handlers."""
    no_handlers = CallbackManager([])
    context = ServiceContext.from_defaults(
        llm_predictor=patch_llm_predictor, callback_manager=no_handlers
    )
    return context
@pytest.fixture()
def refine_instance(mock_refine_service_context: ServiceContext) -> Refine:
    """Non-streaming, verbose Refine synthesizer with structured answer filtering on."""
    options = {
        "service_context": mock_refine_service_context,
        "streaming": False,
        "verbose": True,
        "structured_answer_filtering": True,
    }
    return Refine(**options)
def test_constructor_args(mock_refine_service_context: ServiceContext) -> None:
    """Refine rejects incompatible combinations of constructor options."""
    ctx = mock_refine_service_context

    # Streaming cannot be combined with structured answer filtering.
    with pytest.raises(ValueError):
        Refine(service_context=ctx, streaming=True, structured_answer_filtering=True)

    # Supplying a program factory without answer filtering is rejected.
    with pytest.raises(ValueError):
        Refine(
            service_context=ctx,
            program_factory=lambda _: MockRefineProgram({}),
            structured_answer_filtering=False,
        )
@pytest.mark.asyncio()
async def test_answer_filtering_one_answer(
    mock_refine_service_context: ServiceContext,
) -> None:
    """Only the chunk the mock program marks as satisfying the query is returned."""
    satisfied_by_chunk = OrderedDict(
        (("input1", False), ("input2", True), ("input3", False))
    )

    def make_program(*args: Any, **kwargs: Any) -> MockRefineProgram:
        # Every program instance shares the same lookup table.
        return MockRefineProgram(satisfied_by_chunk)

    synthesizer = Refine(
        service_context=mock_refine_service_context,
        structured_answer_filtering=True,
        program_factory=make_program,
    )

    chunks = list(satisfied_by_chunk)
    answer = await synthesizer.aget_response("question", chunks)
    assert answer == "input2"
@pytest.mark.asyncio()
async def test_answer_filtering_no_answers(
    mock_refine_service_context: ServiceContext,
) -> None:
    """When no chunk satisfies the query, Refine yields "Empty Response"."""
    satisfied_by_chunk = OrderedDict.fromkeys(("input1", "input2", "input3"), False)

    def make_program(*args: Any, **kwargs: Any) -> MockRefineProgram:
        # Every program instance shares the same all-False lookup table.
        return MockRefineProgram(satisfied_by_chunk)

    synthesizer = Refine(
        service_context=mock_refine_service_context,
        structured_answer_filtering=True,
        program_factory=make_program,
    )

    chunks = list(satisfied_by_chunk)
    answer = await synthesizer.aget_response("question", chunks)
    assert answer == "Empty Response"