diff --git a/llm/cli.py b/llm/cli.py index 2e11e2c8..d6f3d206 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -675,11 +675,6 @@ def read_prompt(): raise click.ClickException(str(ex)) extract = template_obj.extract extract_last = template_obj.extract_last - # Combine with template fragments/system_fragments - if template_obj.fragments: - fragments = [*template_obj.fragments, *fragments] - if template_obj.system_fragments: - system_fragments = [*template_obj.system_fragments, *system_fragments] if template_obj.schema_object: schema = template_obj.schema_object if template_obj.tools: @@ -711,6 +706,12 @@ def read_prompt(): raise click.ClickException(str(ex)) if model_id is None and template_obj.model: model_id = template_obj.model + # Combine with template fragments/system_fragments AFTER evaluation + # so that any variables in the fragments have been interpolated + if template_obj.fragments: + fragments = [*template_obj.fragments, *fragments] + if template_obj.system_fragments: + system_fragments = [*template_obj.system_fragments, *system_fragments] # Merge in any attachments if template_obj.attachments: attachments = [ diff --git a/llm/templates.py b/llm/templates.py index 657a4764..5a584af6 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -53,14 +53,29 @@ def evaluate( else: prompt = self.interpolate(self.prompt, params) system = self.interpolate(self.system, params) + + # Interpolate fragments + if self.fragments: + self.fragments = [interpolated for fragment in self.fragments if (interpolated := self.interpolate(fragment, params)) is not None] + if self.system_fragments: + self.system_fragments = [interpolated for fragment in self.system_fragments if (interpolated := self.interpolate(fragment, params)) is not None] + return prompt, system def vars(self) -> set: all_vars = set() + # Check prompt and system for text in [self.prompt, self.system]: if not text: continue all_vars.update(self.extract_vars(string.Template(text))) + # Check fragments and system_fragments + for fragment_list in [self.fragments, self.system_fragments]: + if not fragment_list: + continue + for fragment in fragment_list: + if fragment: + all_vars.update(self.extract_vars(string.Template(fragment))) return all_vars @classmethod diff --git a/tests/test_templates.py b/tests/test_templates.py index dcdabd9b..fe7a39fc 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -33,9 +33,7 @@ ), ), ) -def test_template_evaluate( - prompt, system, defaults, params, expected_prompt, expected_system, expected_error -): +def test_template_evaluate(prompt, system, defaults, params, expected_prompt, expected_system, expected_error): t = Template(name="t", prompt=prompt, system=system, defaults=defaults) if expected_error: with pytest.raises(Template.MissingVariables) as ex: @@ -47,6 +45,49 @@ def test_template_evaluate( assert system == expected_system +def test_template_evaluate_with_fragments(): + """Test that fragments and system_fragments support interpolation""" + t = Template( + name="t", + prompt="Main prompt: $input", + fragments=["Fragment 1: $input", "Fragment 2: $var2"], + system_fragments=["System fragment: $sys_var"], + ) + prompt, system = t.evaluate("test input", {"var2": "value2", "sys_var": "sys_value"}) + + # Check that prompt and system are correctly interpolated + assert prompt == "Main prompt: test input" + + # Check that fragments are interpolated + assert t.fragments == ["Fragment 1: test input", "Fragment 2: value2"] + assert t.system_fragments == ["System fragment: sys_value"] + + +def test_template_evaluate_with_fragments_missing_vars(): + """Test that missing variables in fragments raise an error""" + t = Template( + name="t", + prompt="Main prompt: $input", + fragments=["Fragment with $missing_var"], + ) + with pytest.raises(Template.MissingVariables) as ex: + t.evaluate("test input", {}) + assert "missing_var" in ex.value.args[0] + + +def test_template_vars_includes_fragments(): + """Test that the vars() method includes variables from fragments""" + t = Template( + name="t", + prompt="Prompt with $prompt_var", + system="System with $system_var", + fragments=["Fragment with $fragment_var"], + system_fragments=["System fragment with $sys_fragment_var"], + ) + vars = t.vars() + assert vars == {"prompt_var", "system_var", "fragment_var", "sys_fragment_var"} + + def test_templates_list_no_templates_found(): runner = CliRunner() result = runner.invoke(cli, ["templates", "list"]) @@ -58,15 +99,9 @@ def test_templates_list_no_templates_found(): def test_templates_list(templates_path, args): (templates_path / "one.yaml").write_text("template one", "utf-8") (templates_path / "two.yaml").write_text("template two", "utf-8") - (templates_path / "three.yaml").write_text( - "template three is very long " * 4, "utf-8" - ) - (templates_path / "four.yaml").write_text( - "'this one\n\nhas newlines in it'", "utf-8" - ) - (templates_path / "both.yaml").write_text( - "system: summarize this\nprompt: $input", "utf-8" - ) + (templates_path / "three.yaml").write_text("template three is very long " * 4, "utf-8") + (templates_path / "four.yaml").write_text("'this one\n\nhas newlines in it'", "utf-8") + (templates_path / "both.yaml").write_text("system: summarize this\nprompt: $input", "utf-8") (templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8") (templates_path / "invalid.yaml").write_text("system2: This is invalid", "utf-8") runner = CliRunner() @@ -115,11 +150,7 @@ def test_templates_list(templates_path, args): "--schema", '{"properties": {"b": {"type": "string"}, "a": {"type": "string"}}}', ], - { - "schema_object": { - "properties": {"b": {"type": "string"}, "a": {"type": "string"}} - } - }, + {"schema_object": {"properties": {"b": {"type": "string"}, "a": {"type": "string"}}}}, None, ), # And fragments and system_fragments @@ -164,9 +195,7 @@ def test_templates_prompt_save(templates_path, args, expected, expected_error): yaml_data = yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8")) # Adjust attachment and attachment_types paths to be just the filename if "attachments" in yaml_data: - yaml_data["attachments"] = [ - os.path.basename(path) for path in yaml_data["attachments"] - ] + yaml_data["attachments"] = [os.path.basename(path) for path in yaml_data["attachments"]] for item in yaml_data.get("attachment_types", []): item["value"] = os.path.basename(item["value"]) assert yaml_data == expected @@ -177,18 +206,12 @@ def test_templates_prompt_save(templates_path, args, expected, expected_error): def test_templates_error_on_missing_schema(templates_path): runner = CliRunner() - runner.invoke( - cli, ["the-prompt", "--save", "prompt_no_schema"], catch_exceptions=False - ) + runner.invoke(cli, ["the-prompt", "--save", "prompt_no_schema"], catch_exceptions=False) # This should complain about no schema - result = runner.invoke( - cli, ["hi", "--schema", "t:prompt_no_schema"], catch_exceptions=False - ) + result = runner.invoke(cli, ["hi", "--schema", "t:prompt_no_schema"], catch_exceptions=False) assert result.output == "Error: Template 'prompt_no_schema' has no schema\n" # And this is just an invalid template - result2 = runner.invoke( - cli, ["hi", "--schema", "t:bad_template"], catch_exceptions=False - ) + result2 = runner.invoke(cli, ["hi", "--schema", "t:bad_template"], catch_exceptions=False) assert result2.output == "Error: Invalid template: bad_template\n" @@ -316,9 +339,7 @@ def test_execute_prompt_with_a_template( runner = CliRunner() result = runner.invoke( cli, - ["--no-stream", "-t", "template"] - + ([input_text] if input_text else []) - + extra_args, + ["--no-stream", "-t", "template"] + ([input_text] if input_text else []) + extra_args, catch_exceptions=False, ) if isinstance(expected_input, str): @@ -446,9 +467,7 @@ def register_tools(self, register): ("plugin", True, False), ), ) -def test_tools_in_templates( - source, expected_tool_success, expected_functions_success, httpx_mock, tmpdir -): +def test_tools_in_templates(source, expected_tool_success, expected_functions_success, httpx_mock, tmpdir): template_yaml = textwrap.dedent( """ name: test