-
Notifications
You must be signed in to change notification settings - Fork 337
/
Mixin.ts
89 lines (71 loc) · 2.26 KB
/
Mixin.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Maps a prompt key to its prompt. Values are opaque to this mixin.
type PromptsDict = Record<string, any>;
// Maps a sub-module name to a module object exposing getPrompts()/updatePrompts().
type ModuleDict = Record<string, any>;

/**
 * Mixin for objects that own prompts and/or contain sub-modules that own prompts.
 *
 * Prompts of a sub-module are exposed under namespaced keys of the form
 * `<moduleName>:<promptKey>`. Because sub-modules may themselves be
 * `PromptMixin`s, keys can nest, e.g. `"mid:leaf:tmpl"`. For that reason ':'
 * is reserved and may not appear in a plain prompt key or a module key.
 *
 * Subclasses override `_getPrompts`, `_getPromptModules` and `_updatePrompts`.
 */
export class PromptMixin {
  /**
   * Validates the prompt keys and module keys.
   *
   * @param promptsDict - Prompts owned directly by this object.
   * @param moduleDict - Sub-modules whose prompts are exposed under `<module>:<key>`.
   * @throws Error if any own prompt key or module key contains ':'.
   */
  validatePrompts(promptsDict: PromptsDict, moduleDict: ModuleDict): void {
    // Object.keys: only own enumerable keys, never inherited ones.
    for (const key of Object.keys(promptsDict)) {
      if (key.includes(":")) {
        throw new Error(`Prompt key ${key} cannot contain ':'.`);
      }
    }
    for (const key of Object.keys(moduleDict)) {
      if (key.includes(":")) {
        throw new Error(`Module key ${key} cannot contain ':'.`);
      }
    }
  }

  /**
   * Returns all prompts from the mixin and its modules.
   *
   * Own prompts keep their keys. Every prompt reported by a sub-module's
   * `getPrompts()` is added as `<moduleName>:<key>`, so nested modules yield
   * multi-segment keys.
   *
   * @returns A new object; mutating it does not affect this mixin.
   * @throws Error if a prompt key or module key contains ':'.
   */
  getPrompts(): PromptsDict {
    const promptsDict = this._getPrompts();
    const moduleDict = this._getPromptModules();
    this.validatePrompts(promptsDict, moduleDict);

    const allPrompts: PromptsDict = { ...promptsDict };
    for (const [moduleName, promptModule] of Object.entries(moduleDict)) {
      for (const [key, prompt] of Object.entries(promptModule.getPrompts())) {
        allPrompts[`${moduleName}:${key}`] = prompt;
      }
    }
    return allPrompts;
  }

  /**
   * Updates the prompts in the mixin and its modules.
   *
   * The whole dict is first handed to `_updatePrompts`, which includes namespaced
   * keys intended for sub-modules. Each namespaced key `<module>:<rest>` is then
   * forwarded to that module's `updatePrompts` as `<rest>`. Only the FIRST ':'
   * is split on, so `"a:b:key"` reaches module `a` as `"b:key"`.
   *
   * @param promptsDict - Prompts to set, keyed as returned by `getPrompts()`.
   * @throws Error if a namespaced key names a module this object does not own.
   */
  updatePrompts(promptsDict: PromptsDict): void {
    const promptModules = this._getPromptModules();
    this._updatePrompts(promptsDict);

    // Group namespaced keys by module name. A Map is used instead of a plain
    // object so a module name such as "__proto__" or "constructor" cannot
    // resolve to (and mutate) Object.prototype or Object.
    const subPromptDicts = new Map<string, PromptsDict>();
    for (const key of Object.keys(promptsDict)) {
      const sepIndex = key.indexOf(":");
      if (sepIndex === -1) {
        continue; // Plain key: handled by _updatePrompts above.
      }
      const moduleName = key.slice(0, sepIndex);
      const subKey = key.slice(sepIndex + 1);
      let subPromptDict = subPromptDicts.get(moduleName);
      if (subPromptDict === undefined) {
        subPromptDict = {};
        subPromptDicts.set(moduleName, subPromptDict);
      }
      subPromptDict[subKey] = promptsDict[key];
    }

    for (const [moduleName, subPromptDict] of subPromptDicts) {
      // Only own properties are modules; inherited members such as
      // "constructor" or "toString" must not be treated as sub-modules.
      const moduleToUpdate = Object.prototype.hasOwnProperty.call(
        promptModules,
        moduleName,
      )
        ? promptModules[moduleName]
        : undefined;
      if (!moduleToUpdate) {
        throw new Error(`Module ${moduleName} not found.`);
      }
      moduleToUpdate.updatePrompts(subPromptDict);
    }
  }

  // Must be implemented by subclasses

  /** Returns the prompts owned directly by this object (default: none). */
  protected _getPrompts(): PromptsDict {
    return {};
  }

  /** Returns this object's sub-modules keyed by module name (default: none). */
  protected _getPromptModules(): ModuleDict {
    return {};
  }

  /**
   * Applies prompts to this object itself (default: no-op).
   * Receives the full dict passed to `updatePrompts`, including namespaced keys.
   */
  protected _updatePrompts(promptsDict: PromptsDict): void {
    return;
  }
}