Skip to content

Commit

Permalink
Add args to Jinja filters (#902)
Browse files Browse the repository at this point in the history
In the outlines docs, we have the example
```python
import outlines

def my_tool(arg1: str, arg2: int):
    """Tool description.

    The rest of the docstring
    """
    pass

@outlines.prompt
def tool_prompt(question, tool):
    """{{ question }}

    COMMANDS
    1. {{ tool | name }}: {{ tool | description }}, args: {{ tool | args }}

    {{ tool | source }}
    """

prompt = tool_prompt("Can you do something?", my_tool)
print(prompt)
```
However, when I tried running this code, it did not work because the
`args` filter used in `{{ tool | args }}` was not implemented. I
implemented the `args` filter so now this example works.

Now the args filter will output all of the arguments with the type
annotations and default values (if they are provided).
Example:
```python
from typing import List

def foo(x, y: str, z: List[int]=[1, 2, 3]):
    pass

@outlines.prompt
def tool_prompt(fn):
    """My args: {{ fn | args }}"""

prompt = tool_prompt(foo)
print(prompt)
```
which outputs
```python
My args: x, y: str, z: List[int] = [1, 2, 3]
```
  • Loading branch information
eitanturok committed May 17, 2024
1 parent 315d531 commit 3e291b1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/reference/prompting.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ Several projects (e.g.[Toolformer](https://arxiv.org/abs/2302.04761), [ViperGPT]
Can you do something?

COMMANDS
1. my_tool: Tool description, args: arg1:str, arg2:int
1. my_tool: Tool description., args: arg1: str, arg2: int

def my_tool(arg1: str, arg2: int):
"""Tool description.
Expand Down
13 changes: 13 additions & 0 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
env.filters["source"] = get_fn_source
env.filters["signature"] = get_fn_signature
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

jinja_template = env.from_string(cleaned_template)

Expand All @@ -226,6 +227,18 @@ def get_fn_name(fn: Callable):
return name


def get_fn_args(fn: Callable):
"""Returns the arguments of a function with annotations and default values if provided."""
if not callable(fn):
raise TypeError("The `args` filter only applies to callables.")

arg_str_list = []
signature = inspect.signature(fn)
arg_str_list = [str(param) for param in signature.parameters.values()]
arg_str = ", ".join(arg_str_list)
return arg_str


def get_fn_description(fn: Callable):
"""Returns the first line of a callable's docstring."""
if not callable(fn):
Expand Down
62 changes: 61 additions & 1 deletion tests/test_prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Dict, List

import pytest
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -252,3 +252,63 @@ def source_ppt(model):

prompt = source_ppt(response)
assert prompt == '{\n "one": "a description",\n "two": ""\n}'


def test_prompt_args():
def no_args():
pass

def with_args(x, y, z):
pass

def with_annotations(x: bool, y: str, z: Dict[int, List[str]]):
pass

def with_defaults(x=True, y="Hi", z={4: ["I", "love", "outlines"]}):
pass

def with_annotations_and_defaults(
x: bool = True,
y: str = "Hi",
z: Dict[int, List[str]] = {4: ["I", "love", "outlines"]},
):
pass

def with_all(
x1,
y1,
z1,
x2: bool,
y2: str,
z2: Dict[int, List[str]],
x3=True,
y3="Hi",
z3={4: ["I", "love", "outlines"]},
x4: bool = True,
y4: str = "Hi",
z4: Dict[int, List[str]] = {4: ["I", "love", "outlines"]},
):
pass

@outlines.prompt
def args_prompt(fn):
"""args: {{ fn | args }}"""

assert args_prompt(no_args) == "args: "
assert args_prompt(with_args) == "args: x, y, z"
assert (
args_prompt(with_annotations)
== "args: x: bool, y: str, z: Dict[int, List[str]]"
)
assert (
args_prompt(with_defaults)
== "args: x=True, y='Hi', z={4: ['I', 'love', 'outlines']}"
)
assert (
args_prompt(with_annotations_and_defaults)
== "args: x: bool = True, y: str = 'Hi', z: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}"
)
assert (
args_prompt(with_all)
== "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}"
)

0 comments on commit 3e291b1

Please sign in to comment.