-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
evaluator.py
134 lines (110 loc) · 4.5 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Retrieval evaluators."""
from typing import Any, List, Optional, Sequence, Tuple
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.bridge.pydantic import Field
from llama_index.core.evaluation.retrieval.base import (
BaseRetrievalEvaluator,
RetrievalEvalMode,
)
from llama_index.core.evaluation.retrieval.metrics_base import (
BaseRetrievalMetric,
)
from llama_index.core.indices.base_retriever import BaseRetriever
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import ImageNode, TextNode
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """Retriever evaluator.

    Scores a (text) retriever against a collection of retrieval metrics.

    Args:
        metrics (List[BaseRetrievalMetric]): Sequence of metrics to evaluate
        retriever: Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Post-processors
            applied, in order, to the retrieved nodes before scoring.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Build the evaluator by forwarding every argument to the base constructor."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Retrieve nodes for ``query`` and return their ids and texts.

        Args:
            query: Query string sent to the retriever and to each post-processor.
            mode: Not used by this evaluator; present to match the base-class
                signature. Text is always returned.

        Returns:
            A pair ``(ids, texts)`` with one entry per retrieved (and
            post-processed) node, in the order the nodes were returned.
        """
        results = await self.retriever.aretrieve(query)

        # Each post-processor consumes the output of the previous one.
        for postprocessor in self.node_postprocessors or []:
            results = postprocessor.postprocess_nodes(results, query_str=query)

        ids = [scored.node.node_id for scored in results]
        texts = [scored.node.text for scored in results]
        return ids, texts
class MultiModalRetrieverEvaluator(BaseRetrievalEvaluator):
    """Multi-modal retriever evaluator.

    Evaluates a retriever that may return both text and image nodes, using a
    set of metrics. The ``mode`` passed at evaluation time selects whether the
    text-bearing nodes or the image nodes are scored.

    Args:
        metrics (List[BaseRetrievalMetric]): Sequence of metrics to evaluate
        retriever: Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Post-processors
            applied, in order, to the retrieved nodes before scoring.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts for the requested modality.

        This uses the same method name as ``RetrieverEvaluator`` so that it
        implements the ``BaseRetrievalEvaluator`` retrieval hook.

        Args:
            query: Query string sent to the retriever and to each post-processor.
            mode: ``RetrievalEvalMode.TEXT`` selects every retrieved node whose
                ``text`` is non-empty (image nodes included). ``RetrievalEvalMode.IMAGE``
                selects every retrieved ``ImageNode``.

        Returns:
            A pair ``(ids, texts)`` for the selected nodes, in retrieval order.

        Raises:
            ValueError: If ``mode`` is neither text nor image.
        """
        retrieved_nodes = await self.retriever.aretrieve(query)

        # Each post-processor consumes the output of the previous one.
        if self.node_postprocessors:
            for node_postprocessor in self.node_postprocessors:
                retrieved_nodes = node_postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        image_nodes: List[ImageNode] = []
        text_nodes: List[TextNode] = []
        for scored_node in retrieved_nodes:
            node = scored_node.node
            if isinstance(node, ImageNode):
                image_nodes.append(node)
            # Not exclusive with the branch above: an ImageNode with non-empty
            # text is counted in both lists.
            if node.text:
                text_nodes.append(node)

        # RetrievalEvalMode members compare equal to their string values
        # ("text"/"image"), matching the previous string comparisons.
        if mode == RetrievalEvalMode.TEXT:
            selected = text_nodes
        elif mode == RetrievalEvalMode.IMAGE:
            selected = image_nodes
        else:
            raise ValueError("Unsupported mode.")

        return (
            [node.node_id for node in selected],
            [node.text for node in selected],
        )

    async def _aget_retrieved_ids_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Backward-compatible alias for ``_aget_retrieved_ids_and_texts``.

        Kept so existing callers of the previous method name keep working.
        """
        return await self._aget_retrieved_ids_and_texts(query, mode)