
Commit d9d5299

Migrate from template to prompt arg while keeping backward compatibility (mem0ai#1066)
1 parent 12e6eaf

File tree

9 files changed: +56 −42 lines

Diff for: configs/full-stack.yaml

+1 −1

@@ -15,7 +15,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
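
The commit only renames the key; the value syntax is unchanged. As a quick sanity-check sketch (not part of the commit; the nesting of the config under `llm` is assumed here):

import yaml  # pip install pyyaml

# Parse an abbreviated config using the new `prompt` key (structure assumed).
cfg = yaml.safe_load("""
llm:
  config:
    stream: false
    prompt: |
      Use the following pieces of context to answer the query at the end.
      $context

      Query: $query

      Helpful Answer:
""")
print(cfg["llm"]["config"]["prompt"])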

Diff for: docs/api-reference/advanced/configuration.mdx

+4 −4

@@ -26,7 +26,7 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.

@@ -73,7 +73,7 @@ chunker:
     "max_tokens": 1000,
     "top_p": 1,
     "stream": false,
-    "template": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
+    "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
     "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
     "api_key": "sk-xxx"
 }

@@ -117,7 +117,7 @@ config = {
     'max_tokens': 1000,
     'top_p': 1,
     'stream': False,
-    'template': (
+    'prompt': (
         "Use the following pieces of context to answer the query at the end.\n"
         "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
         "$context\n\nQuery: $query\n\nHelpful Answer:"

@@ -170,7 +170,7 @@ Alright, let's dive into what each key means in the yaml config above:
 - `max_tokens` (Integer): Controls how many tokens are used in the response.
 - `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
 - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
-- `template` (String): A custom template for the prompt that the model uses to generate responses.
+- `prompt` (String): A prompt for the model to follow when generating responses, requires $context and $query variables.
 - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
 - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
 - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
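
The docs now state that `prompt` requires the $context and $query variables. A minimal sketch of that contract using Python's stdlib `string.Template`, which is what the library uses per the diffs below:

from string import Template

# A prompt must contain $context and $query; $history is optional.
prompt = Template(
    "Use the following pieces of context to answer the query at the end.\n"
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer.\n"
    "$context\n\nQuery: $query\n\nHelpful Answer:"
)

# Substituting both variables yields the final text sent to the model.
print(prompt.substitute(context="Paris is the capital of France.",
                        query="What is the capital of France?"))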

Diff for: docs/examples/rest-api/create.mdx

+1 −1

@@ -37,7 +37,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.

Diff for: embedchain/config/llm/base.py

+31 −18

@@ -1,3 +1,4 @@
+import logging
 import re
 from string import Template
 from typing import Any, Dict, List, Optional

@@ -59,6 +60,7 @@ def __init__(
         self,
         number_documents: int = 3,
         template: Optional[Template] = None,
+        prompt: Optional[Template] = None,
         model: Optional[str] = None,
         temperature: float = 0,
         max_tokens: int = 1000,

@@ -80,8 +82,11 @@ def __init__(
             context, defaults to 1
         :type number_documents: int, optional
         :param template: The `Template` instance to use as a template for
-            prompt, defaults to None
+            prompt, defaults to None (deprecated)
         :type template: Optional[Template], optional
+        :param prompt: The `Template` instance to use as a template for
+            prompt, defaults to None
+        :type prompt: Optional[Template], optional
         :param model: Controls the OpenAI model used, defaults to None
         :type model: Optional[str], optional
         :param temperature: Controls the randomness of the model's output.

@@ -106,8 +111,16 @@ def __init__(
             contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
         """
-        if template is None:
-            template = DEFAULT_PROMPT_TEMPLATE
+        if template is not None:
+            logging.warning(
+                "The `template` argument is deprecated and will be removed in a future version. "
+                + "Please use `prompt` instead."
+            )
+            if prompt is None:
+                prompt = template
+
+        if prompt is None:
+            prompt = DEFAULT_PROMPT_TEMPLATE

         self.number_documents = number_documents
         self.temperature = temperature

@@ -120,37 +133,37 @@ def __init__(
         self.callbacks = callbacks
         self.api_key = api_key

-        if type(template) is str:
-            template = Template(template)
+        if type(prompt) is str:
+            prompt = Template(prompt)

-        if self.validate_template(template):
-            self.template = template
+        if self.validate_prompt(prompt):
+            self.prompt = prompt
         else:
-            raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
+            raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")

         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
         self.where = where

-    def validate_template(self, template: Template) -> bool:
+    def validate_prompt(self, prompt: Template) -> bool:
         """
-        validate the template
+        validate the prompt

-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(query_re, template.template) and re.search(context_re, template.template)
+        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

-    def _validate_template_history(self, template: Template) -> bool:
+    def _validate_prompt_history(self, prompt: Template) -> bool:
         """
-        validate the template with history
+        validate the prompt with history

-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(history_re, template.template)
+        return re.search(history_re, prompt.template)
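
A self-contained sketch of the fallback chain introduced above; `DEFAULT_PROMPT_TEMPLATE` here is a stand-in, not the library's actual default:

import logging
from string import Template

DEFAULT_PROMPT_TEMPLATE = Template("$context\n\nQuery: $query\n\nHelpful Answer:")  # stand-in

def resolve_prompt(template=None, prompt=None):
    # Deprecated `template` still works, but warns and defers to `prompt`.
    if template is not None:
        logging.warning(
            "The `template` argument is deprecated and will be removed in a future version. "
            "Please use `prompt` instead."
        )
        if prompt is None:
            prompt = template
    if prompt is None:
        prompt = DEFAULT_PROMPT_TEMPLATE
    # Plain strings are promoted to Template instances, as in the commit.
    if isinstance(prompt, str):
        prompt = Template(prompt)
    return prompt

# An old-style call keeps working, with a deprecation warning:
print(resolve_prompt(template="Q: $query\nCtx: $context").template)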

Diff for: embedchain/llm/base.py

+14 −14

@@ -74,19 +74,19 @@ def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)

-        template_contains_history = self.config._validate_template_history(self.config.template)
-        if template_contains_history:
-            # Template contains history
+        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
+        if prompt_contains_history:
+            # Prompt contains history
             # If there is no history yet, we insert `- no history -`
-            prompt = self.config.template.substitute(
+            prompt = self.config.prompt.substitute(
                 context=context_string, query=input_query, history=self.history or "- no history -"
             )
-        elif self.history and not template_contains_history:
-            # History is present, but not included in the template.
-            # check if it's the default template without history
+        elif self.history and not prompt_contains_history:
+            # History is present, but not included in the prompt.
+            # check if it's the default prompt without history
             if (
-                not self.config._validate_template_history(self.config.template)
-                and self.config.template.template == DEFAULT_PROMPT
+                not self.config._validate_prompt_history(self.config.prompt)
+                and self.config.prompt.template == DEFAULT_PROMPT
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(

@@ -95,12 +95,12 @@ def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                 logging.warning(
-                    "Your bot contains a history, but template does not include `$history` key. History is ignored."
+                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                 )
-                prompt = self.config.template.substitute(context=context_string, query=input_query)
+                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         else:
             # basic use case, no history.
-            prompt = self.config.template.substitute(context=context_string, query=input_query)
+            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt

@@ -191,7 +191,7 @@ def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = N
             return contexts

         if self.is_docs_site_instance:
-            self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+            self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
         k = {}
         if self.online:

@@ -242,7 +242,7 @@ def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = No
             self.config = config

         if self.is_docs_site_instance:
-            self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+            self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
         k = {}
         if self.online:
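
A minimal sketch of the history substitution above, using only the stdlib; "- no history -" is the marker the commit inserts when the prompt declares $history but no turns exist yet:

from string import Template

prompt_with_history = Template("Context: $context | Query: $query | History: $history")
history = []  # no conversation turns yet

# With no history, the placeholder marker is substituted instead.
text = prompt_with_history.substitute(
    context="Some retrieved context",
    query="Some user query",
    history=history or "- no history -",
)
print(text)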

Diff for: embedchain/utils.py

+1 −0

@@ -396,6 +396,7 @@ def validate_config(config_data):
                 Optional("top_p"): Or(float, int),
                 Optional("stream"): bool,
                 Optional("template"): str,
+                Optional("prompt"): str,
                 Optional("system_prompt"): str,
                 Optional("deployment_name"): str,
                 Optional("where"): dict,

Diff for: tests/helper_classes/test_json_serializable.py

+1 −1

@@ -76,4 +76,4 @@ def test_special_subclasses(self):
         config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
         s = config.serialize()
         new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
-        self.assertEqual(config.template.template, new_config.template.template)
+        self.assertEqual(config.prompt.template, new_config.prompt.template)

Diff for: tests/llm/test_base_llm.py

+1 −1

@@ -25,7 +25,7 @@ def test_is_stream_bool():
 def test_template_string_gets_converted_to_Template_instance():
     config = BaseLlmConfig(template="test value $query $context")
     llm = BaseLlm(config=config)
-    assert isinstance(llm.config.template, Template)
+    assert isinstance(llm.config.prompt, Template)


 def test_is_get_llm_model_answer_implemented():
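
The test relies on string-to-Template promotion plus validation. The validation regexes themselves are not part of this diff, so the patterns below are plausible stand-ins for `query_re` and `context_re`:

import re
from string import Template

query_re = re.compile(r"\$\{?query\}?")      # stand-in pattern
context_re = re.compile(r"\$\{?context\}?")  # stand-in pattern

def validate_prompt(prompt: Template) -> bool:
    # Mirrors validate_prompt: both $query and $context must be present.
    return bool(query_re.search(prompt.template) and context_re.search(prompt.template))

assert validate_prompt(Template("test value $query $context"))
assert not validate_prompt(Template("missing context: $query"))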

Diff for: tests/llm/test_generate_prompt.py

+2 −2

@@ -53,15 +53,15 @@ def test_generate_prompt_with_contexts_list(self):
         result = self.app.llm.generate_prompt(input_query, contexts)

         # Assert
-        expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
+        expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
         self.assertEqual(result, expected_result)

     def test_generate_prompt_with_history(self):
         """
         Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
         """
         config = BaseLlmConfig()
-        config.template = Template("Context: $context | Query: $query | History: $history")
+        config.prompt = Template("Context: $context | Query: $query | History: $history")
         self.app.llm.config = config
         self.app.llm.set_history(["Past context 1", "Past context 2"])
         prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
