Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 85 additions & 31 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from contextlib import ExitStack, contextmanager
from copy import deepcopy
from typing import Any, Dict, Tuple, Type, Union # noqa: UP035
import importlib

from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

import dsp
from dspy.adapters.image_utils import Image
from dspy.adapters.image_utils import Image # noqa: F401
from dspy.signatures.field import InputField, OutputField, new_to_old_field


Expand Down Expand Up @@ -355,7 +356,7 @@ def make_signature(
if type_ is None:
type_ = str
# if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type):
if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias)):
if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias, typing._SpecialForm)):
raise ValueError(f"Field types must be types, not {type(type_)}")
if not isinstance(field, FieldInfo):
raise ValueError(f"Field values must be Field instances, not {type(field)}")
Expand Down Expand Up @@ -400,53 +401,106 @@ def _parse_arg_string(string: str, names=None) -> Dict[str, str]:


def _parse_type_node(node, names=None) -> Any:
"""Recursively parse an AST node representing a type annotation.

without using structural pattern matching introduced in Python 3.10.
"""
"""Recursively parse an AST node representing a type annotation."""

if names is None:
names = typing.__dict__
names = dict(typing.__dict__)
names['NoneType'] = type(None)

def resolve_name(id_: str):
# Check if it's a built-in known type or in the provided names
if id_ in names:
return names[id_]

# Common built-in types
builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray]

# Try PIL Image if 'Image' encountered
if 'Image' not in names:
try:
from PIL import Image
names['Image'] = Image
except ImportError:
pass

# If we have PIL Image and id_ is 'Image', return it
if 'Image' in names and id_ == 'Image':
return names['Image']

# Check if it matches any known built-in type by name
for t in builtin_types:
if t.__name__ == id_:
return t

# Attempt to import a module with this name dynamically
# This allows handling of module-based annotations like `dspy.Image`.
try:
mod = importlib.import_module(id_)
names[id_] = mod
return mod
except ImportError:
pass

# If we don't know the type or module, raise an error
raise ValueError(f"Unknown name: {id_}")

if isinstance(node, ast.Module):
body = node.body
if len(body) != 1:
raise ValueError(f"Code is not syntactically valid: {node}")
return _parse_type_node(body[0], names)
if len(node.body) != 1:
raise ValueError(f"Code is not syntactically valid: {ast.dump(node)}")
return _parse_type_node(node.body[0], names)

if isinstance(node, ast.Expr):
value = node.value
return _parse_type_node(value, names)
return _parse_type_node(node.value, names)

if isinstance(node, ast.Name):
id_ = node.id
if id_ in names:
return names[id_]
return resolve_name(node.id)

for type_ in [int, str, float, bool, list, tuple, dict, Image]:
if type_.__name__ == id_:
return type_
raise ValueError(f"Unknown name: {id_}")
if isinstance(node, ast.Attribute):
base = _parse_type_node(node.value, names)
attr_name = node.attr
if hasattr(base, attr_name):
return getattr(base, attr_name)
else:
raise ValueError(f"Unknown attribute: {attr_name} on {base}")

if isinstance(node, ast.Subscript):
base_type = _parse_type_node(node.value, names)
arg_type = _parse_type_node(node.slice, names)
return base_type[arg_type]
slice_node = node.slice
if isinstance(slice_node, ast.Index): # For older Python versions
slice_node = slice_node.value

if isinstance(slice_node, ast.Tuple):
arg_types = tuple(_parse_type_node(elt, names) for elt in slice_node.elts)
else:
arg_types = (_parse_type_node(slice_node, names),)

# Special handling for Union, Optional
if base_type is typing.Union:
return typing.Union[arg_types]
if base_type is typing.Optional:
if len(arg_types) != 1:
raise ValueError("Optional must have exactly one type argument")
return typing.Optional[arg_types[0]]

return base_type[arg_types]

if isinstance(node, ast.Tuple):
elts = node.elts
return tuple(_parse_type_node(elt, names) for elt in elts)
return tuple(_parse_type_node(elt, names) for elt in node.elts)

if isinstance(node, ast.Constant):
return node.value

if isinstance(node, ast.Call) and node.func.id == "Field":
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "Field":
keys = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
values = []
for kw in node.keywords:
if isinstance(kw.value, ast.Constant):
values.append(kw.value.value)
else:
values.append(_parse_type_node(kw.value, names))
return Field(**dict(zip(keys, values)))

if isinstance(node, ast.Attribute) and node.attr == "Image":
return Image

raise ValueError(f"Code is not syntactically valid: {node}")

raise ValueError(f"Unhandled AST node type in annotation: {ast.dump(node)}")

def infer_prefix(attribute_name: str) -> str:
"""Infer a prefix from an attribute name."""
Expand Down
124 changes: 123 additions & 1 deletion tests/signatures/test_signature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import textwrap
from typing import List
from typing import Any, Dict, List, Optional, Tuple, Union

import pydantic
import pytest
Expand Down Expand Up @@ -279,3 +279,125 @@ class CustomSignature2(dspy.Signature):
assert CustomSignature2.instructions == "I am a malicious instruction."
assert CustomSignature2.fields["sentence"].json_schema_extra["desc"] == "I am an malicious input!"
assert CustomSignature2.fields["sentiment"].json_schema_extra["prefix"] == "Sentiment:"


def test_typed_signatures_basic_types():
    """A string signature with plain built-in annotations keeps those types."""
    sig = Signature("input1: int, input2: str -> output: float")

    # Inputs: each declared name must be present with its built-in annotation.
    for field_name, expected_type in (("input1", int), ("input2", str)):
        assert field_name in sig.input_fields
        assert sig.input_fields[field_name].annotation == expected_type

    # Output: the single field after the arrow.
    assert "output" in sig.output_fields
    assert sig.output_fields["output"].annotation == float


def test_typed_signatures_generics():
    """Parameterized ``typing`` generics are preserved when parsed from a string."""
    sig = Signature("input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]")

    expected_inputs = (
        ("input_list", List[int]),
        ("input_dict", Dict[str, float]),
    )
    for field_name, expected_type in expected_inputs:
        assert field_name in sig.input_fields
        assert sig.input_fields[field_name].annotation == expected_type

    assert "output_tuple" in sig.output_fields
    assert sig.output_fields["output_tuple"].annotation == Tuple[str, int]


def test_typed_signatures_unions_and_optionals():
    """``Optional`` and ``Union`` annotations parse into unions with the expected members."""
    sig = Signature("input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]")
    none_type = type(None)

    def is_union_containing(annotation, first, second):
        # True when the annotation is a Union whose arguments include both members.
        if getattr(annotation, '__origin__', None) is not Union:
            return False
        return first in annotation.__args__ and second in annotation.__args__

    assert "input_opt" in sig.input_fields
    # Optional[str] is Union[str, None]; depending on the environment the
    # annotation may compare equal to Optional[str] or only expose the Union
    # structure, so either form is accepted.
    opt_annotation = sig.input_fields["input_opt"].annotation
    assert opt_annotation == Optional[str] or is_union_containing(opt_annotation, str, none_type)

    assert "input_union" in sig.input_fields
    in_union_annotation = sig.input_fields["input_union"].annotation
    assert is_union_containing(in_union_annotation, int, none_type)

    assert "output_union" in sig.output_fields
    out_union_annotation = sig.output_fields["output_union"].annotation
    assert is_union_containing(out_union_annotation, int, str)


def test_typed_signatures_any():
    """``Any`` is accepted as an annotation on both sides of the arrow."""
    sig = Signature("input_any: Any -> output_any: Any")

    in_fields = sig.input_fields
    assert "input_any" in in_fields
    assert in_fields["input_any"].annotation == Any

    out_fields = sig.output_fields
    assert "output_any" in out_fields
    assert out_fields["output_any"].annotation == Any


def test_typed_signatures_nested():
    """Generics nested inside unions (and vice versa) keep their full structure."""
    sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]")

    def origin_of(annotation):
        # Unsubscripted types have no __origin__; treat that as None.
        return getattr(annotation, '__origin__', None)

    # Input: List[Union[str, int]]
    list_ann = sig.input_fields["input_nested"].annotation
    assert origin_of(list_ann) is list
    assert len(list_ann.__args__) == 1
    element_ann = list_ann.__args__[0]
    assert origin_of(element_ann) is Union
    assert str in element_ann.__args__ and int in element_ann.__args__

    # Output: Tuple[int, Optional[float], List[str]]
    tuple_ann = sig.output_fields["output_nested"].annotation
    assert origin_of(tuple_ann) is tuple
    assert tuple_ann.__args__[0] == int

    # Second member is Optional[float], i.e. Union[float, None].
    optional_float = tuple_ann.__args__[1]
    assert origin_of(optional_float) is Union
    assert float in optional_float.__args__ and type(None) in optional_float.__args__

    # Third member is List[str].
    str_list = tuple_ann.__args__[2]
    assert origin_of(str_list) is list
    assert str_list.__args__[0] == str


def test_typed_signatures_from_dict():
    """A signature built from a ``{name: (type, field)}`` dict keeps each annotation."""
    fields = {
        "input_str_list": (List[str], InputField()),
        "input_dict_int": (Dict[str, int], InputField()),
        "output_tup": (Tuple[int, float], OutputField()),
    }
    sig = Signature(fields)

    for field_name, expected_type in (("input_str_list", List[str]), ("input_dict_int", Dict[str, int])):
        assert field_name in sig.input_fields
        assert sig.input_fields[field_name].annotation == expected_type

    assert "output_tup" in sig.output_fields
    assert sig.output_fields["output_tup"].annotation == Tuple[int, float]


def test_typed_signatures_complex_combinations():
    """Deeply nested generics, optionals and unions are parsed level by level."""
    sig = Signature("input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]")

    def origin_of(annotation):
        # Unsubscripted types have no __origin__; treat that as None.
        return getattr(annotation, '__origin__', None)

    # Input: Dict[str, List[Optional[Tuple[int, str]]]]
    dict_ann = sig.input_fields["input_complex"].annotation
    assert origin_of(dict_ann) is dict
    key_arg, value_arg = dict_ann.__args__
    assert key_arg == str

    # The dict value is List[Optional[Tuple[int, str]]].
    assert origin_of(value_arg) is list
    optional_tuple = value_arg.__args__[0]

    # Optional[Tuple[int, str]] is Union[Tuple[int, str], None].
    assert origin_of(optional_tuple) is Union
    non_none_members = [member for member in optional_tuple.__args__ if member != type(None)]
    tuple_type = non_none_members[0]
    assert origin_of(tuple_type) is tuple
    assert tuple_type.__args__ == (int, str)

    # Output: Union[List[str], Dict[str, Any]]
    union_ann = sig.output_fields["output_complex"].annotation
    assert origin_of(union_ann) is Union
    assert len(union_ann.__args__) == 2

    # Locate each member by its origin; a set is used, so order is not relied on.
    members = set(union_ann.__args__)
    list_member = next(m for m in members if origin_of(m) is list)
    dict_member = next(m for m in members if origin_of(m) is dict)
    assert list_member.__args__ == (str,)
    dict_key, dict_value = dict_member.__args__
    assert dict_key == str and dict_value == Any
Loading