From b87fd5c9543ac1017bb4c5fdff420222a2d477ed Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sat, 7 Dec 2024 21:47:56 +0100 Subject: [PATCH 1/5] Added support for more types --- dspy/signatures/signature.py | 114 +++++++++++++++++++------- tests/signatures/test_signature.py | 124 ++++++++++++++++++++++++++++- 2 files changed, 208 insertions(+), 30 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 387b9d771d..f1b84e5d39 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -7,6 +7,7 @@ from contextlib import ExitStack, contextmanager from copy import deepcopy from typing import Any, Dict, Tuple, Type, Union # noqa: UP035 +import importlib from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo @@ -400,53 +401,108 @@ def _parse_arg_string(string: str, names=None) -> Dict[str, str]: def _parse_type_node(node, names=None) -> Any: - """Recursively parse an AST node representing a type annotation. - - without using structural pattern matching introduced in Python 3.10. - """ + """Recursively parse an AST node representing a type annotation.""" if names is None: - names = typing.__dict__ + from typing import Any, Union, Optional + names = dict(typing.__dict__) + from types import NoneType + names['NoneType'] = NoneType + + def resolve_name(id_: str): + # Check if it's a built-in known type or in the provided names + if id_ in names: + return names[id_] + + # Common built-in types + builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray] + + # Try PIL Image if 'Image' encountered + if 'Image' not in names: + try: + from PIL import Image + names['Image'] = Image + except ImportError: + pass + + # If we have PIL Image and id_ is 'Image', return it + if 'Image' in names and id_ == 'Image': + return names['Image'] + + # Check if it matches any known built-in type by name + for t in builtin_types: + if t.__name__ == id_: + return t + + # Attempt to import a module with this name dynamically + # This allows handling of module-based annotations like `dspy.Image`. + try: + mod = importlib.import_module(id_) + names[id_] = mod + return mod + except ImportError: + pass + + # If we don't know the type or module, raise an error + raise ValueError(f"Unknown name: {id_}") if isinstance(node, ast.Module): - body = node.body - if len(body) != 1: - raise ValueError(f"Code is not syntactically valid: {node}") - return _parse_type_node(body[0], names) + if len(node.body) != 1: + raise ValueError(f"Code is not syntactically valid: {ast.dump(node)}") + return _parse_type_node(node.body[0], names) if isinstance(node, ast.Expr): - value = node.value - return _parse_type_node(value, names) + return _parse_type_node(node.value, names) if isinstance(node, ast.Name): - id_ = node.id - if id_ in names: - return names[id_] + return resolve_name(node.id) - for type_ in [int, str, float, bool, list, tuple, dict, Image]: - if type_.__name__ == id_: - return type_ - raise ValueError(f"Unknown name: {id_}") + if isinstance(node, ast.Attribute): + base = _parse_type_node(node.value, names) + attr_name = node.attr + if hasattr(base, attr_name): + return getattr(base, attr_name) + else: + raise ValueError(f"Unknown attribute: {attr_name} on {base}") if isinstance(node, ast.Subscript): base_type = _parse_type_node(node.value, names) - arg_type = _parse_type_node(node.slice, names) - return base_type[arg_type] + slice_node = node.slice + if isinstance(slice_node, ast.Index): # For older Python versions + slice_node = slice_node.value + + if isinstance(slice_node, ast.Tuple): + arg_types = tuple(_parse_type_node(elt, names) for elt in slice_node.elts) + else: + arg_types = (_parse_type_node(slice_node, names),) + + # Special handling for Union, Optional + if base_type is typing.Union: + return typing.Union[arg_types] + if base_type is typing.Optional: + if len(arg_types) != 1: + raise ValueError("Optional must have exactly one type argument") + return typing.Optional[arg_types[0]] + + return base_type[arg_types] if isinstance(node, ast.Tuple): - elts = node.elts - return tuple(_parse_type_node(elt, names) for elt in elts) + return tuple(_parse_type_node(elt, names) for elt in node.elts) + + if isinstance(node, ast.Constant): + return node.value - if isinstance(node, ast.Call) and node.func.id == "Field": + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "Field": keys = [kw.arg for kw in node.keywords] - values = [kw.value.value for kw in node.keywords] + values = [] + for kw in node.keywords: + if isinstance(kw.value, ast.Constant): + values.append(kw.value.value) + else: + values.append(_parse_type_node(kw.value, names)) return Field(**dict(zip(keys, values))) - if isinstance(node, ast.Attribute) and node.attr == "Image": - return Image - - raise ValueError(f"Code is not syntactically valid: {node}") - + raise ValueError(f"Unhandled AST node type in annotation: {ast.dump(node)}") def infer_prefix(attribute_name: str) -> str: """Infer a prefix from an attribute name.""" diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index bc63db1ac3..8aeebdc34c 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -1,5 +1,5 @@ import textwrap -from typing import List +from typing import Any, Dict, List, Optional, Tuple, Union import pydantic import pytest @@ -279,3 +279,125 @@ class CustomSignature2(dspy.Signature): assert CustomSignature2.instructions == "I am a malicious instruction." assert CustomSignature2.fields["sentence"].json_schema_extra["desc"] == "I am an malicious input!" assert CustomSignature2.fields["sentiment"].json_schema_extra["prefix"] == "Sentiment:" + + +def test_typed_signatures_basic_types(): + # Simple built-in types + sig = Signature("input1: int, input2: str -> output: float") + assert "input1" in sig.input_fields + assert sig.input_fields["input1"].annotation == int + assert "input2" in sig.input_fields + assert sig.input_fields["input2"].annotation == str + assert "output" in sig.output_fields + assert sig.output_fields["output"].annotation == float + + +def test_typed_signatures_generics(): + # More complex generic types + sig = Signature("input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]") + assert "input_list" in sig.input_fields + assert sig.input_fields["input_list"].annotation == List[int] + assert "input_dict" in sig.input_fields + assert sig.input_fields["input_dict"].annotation == Dict[str, float] + assert "output_tuple" in sig.output_fields + assert sig.output_fields["output_tuple"].annotation == Tuple[str, int] + + +def test_typed_signatures_unions_and_optionals(): + sig = Signature("input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]") + assert "input_opt" in sig.input_fields + # Optional[str] is actually Union[str, None] + # Depending on the environment, it might resolve to Union[str, None] or Optional[str], either is correct. + # We'll just check for a Union containing str and NoneType: + input_opt_annotation = sig.input_fields["input_opt"].annotation + assert (input_opt_annotation == Optional[str] or + (getattr(input_opt_annotation, '__origin__', None) is Union and str in input_opt_annotation.__args__ and type(None) in input_opt_annotation.__args__)) + + assert "input_union" in sig.input_fields + input_union_annotation = sig.input_fields["input_union"].annotation + assert (getattr(input_union_annotation, '__origin__', None) is Union and + int in input_union_annotation.__args__ and type(None) in input_union_annotation.__args__) + + assert "output_union" in sig.output_fields + output_union_annotation = sig.output_fields["output_union"].annotation + assert (getattr(output_union_annotation, '__origin__', None) is Union and + int in output_union_annotation.__args__ and str in output_union_annotation.__args__) + + +def test_typed_signatures_any(): + sig = Signature("input_any: Any -> output_any: Any") + assert "input_any" in sig.input_fields + assert sig.input_fields["input_any"].annotation == Any + assert "output_any" in sig.output_fields + assert sig.output_fields["output_any"].annotation == Any + + +def test_typed_signatures_nested(): + # Nested generics and unions + sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]") + input_nested_ann = sig.input_fields["input_nested"].annotation + assert getattr(input_nested_ann, '__origin__', None) is list + assert len(input_nested_ann.__args__) == 1 + union_arg = input_nested_ann.__args__[0] + assert getattr(union_arg, '__origin__', None) is Union + assert str in union_arg.__args__ and int in union_arg.__args__ + + output_nested_ann = sig.output_fields["output_nested"].annotation + assert getattr(output_nested_ann, '__origin__', None) is tuple + assert output_nested_ann.__args__[0] == int + # The second arg is Optional[float], which is Union[float, None] + second_arg = output_nested_ann.__args__[1] + assert getattr(second_arg, '__origin__', None) is Union + assert float in second_arg.__args__ and type(None) in second_arg.__args__ + # The third arg is List[str] + third_arg = output_nested_ann.__args__[2] + assert getattr(third_arg, '__origin__', None) is list + assert third_arg.__args__[0] == str + + +def test_typed_signatures_from_dict(): + # Creating a Signature directly from a dictionary with types + fields = { + "input_str_list": (List[str], InputField()), + "input_dict_int": (Dict[str, int], InputField()), + "output_tup": (Tuple[int, float], OutputField()), + } + sig = Signature(fields) + assert "input_str_list" in sig.input_fields + assert sig.input_fields["input_str_list"].annotation == List[str] + assert "input_dict_int" in sig.input_fields + assert sig.input_fields["input_dict_int"].annotation == Dict[str, int] + assert "output_tup" in sig.output_fields + assert sig.output_fields["output_tup"].annotation == Tuple[int, float] + + +def test_typed_signatures_complex_combinations(): + # Test a very complex signature with multiple nested constructs + # input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]] + sig = Signature("input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]") + input_complex_ann = sig.input_fields["input_complex"].annotation + assert getattr(input_complex_ann, '__origin__', None) is dict + key_arg, value_arg = input_complex_ann.__args__ + assert key_arg == str + # value_arg: List[Optional[Tuple[int, str]]] + assert getattr(value_arg, '__origin__', None) is list + inner_union = value_arg.__args__[0] + # inner_union should be Optional[Tuple[int, str]] + # which is Union[Tuple[int, str], None] + assert getattr(inner_union, '__origin__', None) is Union + tuple_type = [t for t in inner_union.__args__ if t != type(None)][0] + assert getattr(tuple_type, '__origin__', None) is tuple + assert tuple_type.__args__ == (int, str) + + output_complex_ann = sig.output_fields["output_complex"].annotation + assert getattr(output_complex_ann, '__origin__', None) is Union + assert len(output_complex_ann.__args__) == 2 + possible_args = set(output_complex_ann.__args__) + # Expecting List[str] and Dict[str, Any] + # Because sets don't preserve order, just check membership. + # Find the List[str] arg + list_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is list) + dict_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is dict) + assert list_arg.__args__ == (str,) + k, v = dict_arg.__args__ + assert k == str and v == Any From 77385908f07a3690e1fe22e8feb5828d2631f46b Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sat, 7 Dec 2024 21:53:10 +0100 Subject: [PATCH 2/5] ruff --- dspy/signatures/signature.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index f1b84e5d39..ed493cd0cc 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -13,7 +13,6 @@ from pydantic.fields import FieldInfo import dsp -from dspy.adapters.image_utils import Image from dspy.signatures.field import InputField, OutputField, new_to_old_field @@ -404,7 +403,6 @@ def _parse_type_node(node, names=None) -> Any: """Recursively parse an AST node representing a type annotation.""" if names is None: - from typing import Any, Union, Optional names = dict(typing.__dict__) from types import NoneType names['NoneType'] = NoneType From 7ba3d79dbae23ef29995d94b9ec764eabc160c52 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Mon, 9 Dec 2024 21:31:57 +0100 Subject: [PATCH 3/5] python3.9 support --- dspy/signatures/signature.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index ed493cd0cc..4afddc0619 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -355,7 +355,7 @@ def make_signature( if type_ is None: type_ = str # if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type): - if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias)): + if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias, typing._SpecialForm)): raise ValueError(f"Field types must be types, not {type(type_)}") if not isinstance(field, FieldInfo): raise ValueError(f"Field values must be Field instances, not {type(field)}") @@ -404,8 +404,7 @@ def _parse_type_node(node, names=None) -> Any: if names is None: names = dict(typing.__dict__) - from types import NoneType - names['NoneType'] = NoneType + names['NoneType'] = type(None) def resolve_name(id_: str): # Check if it's a built-in known type or in the provided names From e0101d1bf55ba6e76d51a6128dfe661293fb183a Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Mon, 9 Dec 2024 21:42:09 +0100 Subject: [PATCH 4/5] Missing Image import --- dspy/signatures/signature.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 4afddc0619..0ef752e47b 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -13,6 +13,7 @@ from pydantic.fields import FieldInfo import dsp +from dspy.adapters.image_utils import Image from dspy.signatures.field import InputField, OutputField, new_to_old_field From 3c3dcda73ef7624213fb5e8109a24567bf3e8bda Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Mon, 9 Dec 2024 21:47:03 +0100 Subject: [PATCH 5/5] add ruff ignore comment --- dspy/signatures/signature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 0ef752e47b..56ca996563 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -13,7 +13,7 @@ from pydantic.fields import FieldInfo import dsp -from dspy.adapters.image_utils import Image +from dspy.adapters.image_utils import Image # noqa: F401 from dspy.signatures.field import InputField, OutputField, new_to_old_field