-
Notifications
You must be signed in to change notification settings - Fork 337
/
Correctness.ts
124 lines (110 loc) · 3.2 KB
/
Correctness.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import { MetadataMode } from "../Node.js";
import type { ServiceContext } from "../ServiceContext.js";
import { llmFromSettingsOrContext } from "../Settings.js";
import type { ChatMessage, LLM } from "../llm/types.js";
import { extractText } from "../llm/utils.js";
import { PromptMixin } from "../prompts/Mixin.js";
import type { CorrectnessSystemPrompt } from "./prompts.js";
import {
defaultCorrectnessSystemPrompt,
defaultUserPrompt,
} from "./prompts.js";
import type {
BaseEvaluator,
EvaluationResult,
EvaluatorParams,
EvaluatorResponseParams,
} from "./types.js";
import { defaultEvaluationParser } from "./utils.js";
/**
 * Options accepted by the `CorrectnessEvaluator` constructor.
 * Every field is optional; the constructor fills in defaults for omitted ones.
 */
type CorrectnessParams = {
  /** Minimum score (inclusive) for an evaluation to count as passing. */
  scoreThreshold?: number;
  /** Converts the raw LLM reply text into a `[score, reasoning]` tuple. */
  parserFunction?: (str: string) => [number, string];
  /** Handed to `llmFromSettingsOrContext` to choose the grading LLM. */
  serviceContext?: ServiceContext;
};
/**
 * Correctness evaluator.
 *
 * Asks an LLM to grade a generated answer to a query, optionally against a
 * reference answer, and parses the LLM reply into a numeric score plus a
 * textual justification.
 */
export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator {
  /** Minimum score (inclusive) for a result to be reported as passing. */
  private scoreThreshold: number;
  /** Turns the raw LLM reply text into a `[score, reasoning]` tuple. */
  private parserFunction: (str: string) => [number, string];
  /** LLM used to grade responses. */
  private llm: LLM;
  /** Produces the system prompt sent to the grading LLM. */
  private correctnessPrompt: CorrectnessSystemPrompt =
    defaultCorrectnessSystemPrompt;

  /**
   * @param params Optional configuration.
   * @param params.serviceContext Passed to `llmFromSettingsOrContext` to pick the grading LLM.
   * @param params.scoreThreshold Passing threshold; defaults to `4.0`.
   * @param params.parserFunction Reply parser; defaults to `defaultEvaluationParser`.
   */
  constructor(params?: CorrectnessParams) {
    super();
    this.llm = llmFromSettingsOrContext(params?.serviceContext);
    // `correctnessPrompt` is already initialised by its field initializer.
    this.scoreThreshold = params?.scoreThreshold ?? 4.0;
    this.parserFunction = params?.parserFunction ?? defaultEvaluationParser;
  }

  /**
   * Replaces the system prompt when a `correctnessPrompt` entry is supplied.
   *
   * @param prompts Prompt overrides keyed by prompt name.
   */
  _updatePrompts(prompts: {
    correctnessPrompt: CorrectnessSystemPrompt;
  }): void {
    if ("correctnessPrompt" in prompts) {
      this.correctnessPrompt = prompts["correctnessPrompt"];
    }
  }

  /**
   * Grades `response` for `query` with the configured LLM.
   *
   * @param query Query to evaluate.
   * @param response Response to evaluate.
   * @param contexts Accepted for interface compatibility; not used by this evaluator.
   * @param reference Reference answer; when empty or missing, a
   *   "(NO REFERENCE ANSWER SUPPLIED)" placeholder is sent instead.
   * @returns The parsed score, the LLM's reasoning as `feedback`, and whether
   *   the score meets `scoreThreshold`.
   * @throws Error if `query` or `response` is `null` or `undefined`.
   */
  async evaluate({
    query,
    response,
    reference,
  }: EvaluatorParams): Promise<EvaluationResult> {
    // `== null` rejects both null and undefined; `evaluateResponse` can pass
    // `undefined` when no response object is given.
    if (query == null || response == null) {
      throw new Error("query, and response must be provided");
    }

    const messages: ChatMessage[] = [
      {
        role: "system",
        content: this.correctnessPrompt(),
      },
      {
        role: "user",
        content: defaultUserPrompt({
          query,
          generatedAnswer: response,
          // `||` on purpose: an empty-string reference is treated as absent.
          referenceAnswer: reference || "(NO REFERENCE ANSWER SUPPLIED)",
        }),
      },
    ];

    const evalResponse = await this.llm.chat({ messages });
    const [score, reasoning] = this.parserFunction(
      extractText(evalResponse.message.content),
    );

    return {
      query,
      response,
      // NOTE(review): a `null` score (possible with a custom parser) is
      // reported as passing; a NaN score is not. Confirm this is intended.
      passing: score >= this.scoreThreshold || score === null,
      score,
      feedback: reasoning,
    };
  }

  /**
   * Grades a response object by extracting its text and source-node contents.
   *
   * @param query Query to evaluate.
   * @param response Response object to evaluate; its `response` text is graded.
   * @returns The result of `evaluate`.
   * @throws Error (from `evaluate`) when `response` or its text is missing.
   */
  async evaluateResponse({
    query,
    response,
  }: EvaluatorResponseParams): Promise<EvaluationResult> {
    const responseStr = response?.response;
    const contexts = (response?.sourceNodes ?? []).map((source) =>
      source.node.getContent(MetadataMode.ALL),
    );
    return this.evaluate({
      query,
      response: responseStr,
      contexts,
    });
  }
}