From 6ba413079f8c26a78372c086ff5068ff80a1c2ea Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sat, 2 Mar 2024 19:20:15 -0800 Subject: [PATCH 1/8] Some refinements to signature, allowing instantiation and fixing lint issues. --- dspy/signatures/signature.py | 256 +++++++++++++++-------------- pyproject.toml | 14 -- tests/signatures/test_signature.py | 40 +++-- 3 files changed, 160 insertions(+), 150 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 7cd2f170cb..a335f219f0 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -3,37 +3,42 @@ import dsp from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo -from typing import Type, Union, Dict, Tuple +from typing import Type, Union, Dict, Tuple # noqa: UP035 import re from dspy.signatures.field import InputField, OutputField, new_to_old_field -def signature_to_template(signature): - """Convert from new to legacy format""" +def signature_to_template(signature) -> dsp.Template: + """Convert from new to legacy format.""" return dsp.Template( signature.instructions, **{name: new_to_old_field(field) for name, field in signature.fields.items()}, ) -def _default_instructions(cls): - inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields.keys()]) - outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields.keys()]) +def _default_instructions(cls) -> str: + inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields]) + outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields]) return f"Given the fields {inputs_}, produce the fields {outputs_}." class SignatureMeta(type(BaseModel)): - def __new__(mcs, name, bases, namespace, **kwargs): + def __call__(cls, *args, **kwargs): # noqa: ANN002 + if cls is Signature: + return make_signature(*args, **kwargs) + return super().__call__(*args, **kwargs) + + def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804 # Set `str` as the default type for all fields raw_annotations = namespace.get("__annotations__", {}) - for name, field in namespace.items(): + for name, _field in namespace.items(): if not name.startswith("__") and name not in raw_annotations: raw_annotations[name] = str namespace["__annotations__"] = raw_annotations # Let Pydantic do its thing - cls = super().__new__(mcs, name, bases, namespace, **kwargs) + cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs) if cls.__doc__ is None: cls.__doc__ = _default_instructions(cls) @@ -69,17 +74,20 @@ def signature(cls) -> str: def instructions(cls) -> str: return getattr(cls, "__doc__", "") - def with_instructions(cls, instructions: str): + def with_instructions(cls, instructions: str) -> Type["Signature"]: return Signature(cls.fields, instructions) @property - def fields(cls): + def fields(cls) -> dict[str, FieldInfo]: # Make sure to give input fields before output fields return {**cls.input_fields, **cls.output_fields} - def with_updated_fields(cls, name, type_=None, **kwargs): - """Returns a new Signature type with the field, name, updated - with fields[name].json_schema_extra[key] = value.""" + def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]: + """Update the field, name, in a new Signature type. + + Returns a new Signature type with the field, name, updated + with fields[name].json_schema_extra[key] = value. + """ fields_copy = deepcopy(cls.fields) fields_copy[name].json_schema_extra = { **fields_copy[name].json_schema_extra, @@ -90,27 +98,23 @@ def with_updated_fields(cls, name, type_=None, **kwargs): return Signature(fields_copy, cls.instructions) @property - def input_fields(cls): + def input_fields(cls) -> dict[str, FieldInfo]: return cls._get_fields_with_type("input") @property - def output_fields(cls): + def output_fields(cls) -> dict[str, FieldInfo]: return cls._get_fields_with_type("output") - def _get_fields_with_type(cls, field_type): - return { - k: v - for k, v in cls.model_fields.items() - if v.json_schema_extra["__dspy_field_type"] == field_type - } + def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]: + return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type} - def prepend(cls, name, field, type_=None): + def prepend(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(0, name, field, type_) - def append(cls, name, field, type_=None): + def append(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(-1, name, field, type_) - def insert(cls, index: int, name: str, field, type_: Type = None): + def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signature"]: # It's posisble to set the type as annotation=type in pydantic.Field(...) # But this may be annoying for users, so we allow them to pass the type if type_ is None: @@ -122,11 +126,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None): output_fields = list(cls.output_fields.items()) # Choose the list to insert into based on the field type - lst = ( - input_fields - if field.json_schema_extra["__dspy_field_type"] == "input" - else output_fields - ) + lst = input_fields if field.json_schema_extra["__dspy_field_type"] == "input" else output_fields # We support negative insert indices if index < 0: index += len(lst) + 1 @@ -137,83 +137,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None): new_fields = dict(input_fields + output_fields) return Signature(new_fields, cls.instructions) - def _parse_signature(cls, signature: str) -> Tuple[Type, Field]: - pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$" - if not re.match(pattern, signature): - raise ValueError(f"Invalid signature format: '{signature}'") - - fields = {} - inputs_str, outputs_str = map(str.strip, signature.split("->")) - inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] - outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] - for name in inputs: - fields[name] = (str, InputField()) - for name in outputs: - fields[name] = (str, OutputField()) - - return fields - - def __call__( - cls, - signature: Union[str, Dict[str, Tuple[type, FieldInfo]]], - instructions: str = None, - ): - """ - Creates a new Signature type with the given fields and instructions. - Note: - Even though we're calling a type, we're not making an instance of the type. - In general we don't allow instances of Signature types to be made. The call - syntax is only for your convenience. - Parameters: - signature: Format: "input1, input2 -> output1, output2" - instructions: Optional prompt for the signature. - """ - - if isinstance(signature, str): - fields = cls._parse_signature(signature) - else: - fields = signature - - # Validate the fields, this is important because we sometimes forget the - # slightly unintuitive syntax with tuples of (type, Field) - fixed_fields = {} - for name, type_field in fields.items(): - assert isinstance( - name, str, - ), f"Field names must be strings, not {type(name)}" - if isinstance(type_field, FieldInfo): - type_ = type_field.annotation - field = type_field - else: - assert isinstance( - type_field, tuple, - ), f"Field values must be tuples, not {type(type_field)}" - type_, field = type_field - # It might be better to be explicit about the type, but it currently would break - # program of thought and teleprompters, so we just silently default to string. - if type_ is None: - type_ = str - assert isinstance(type_, type) or isinstance( - typing.get_origin(type_), type, - ), f"Field types must be types, not {type(type_)}" - assert isinstance( - field, FieldInfo, - ), f"Field values must be Field instances, not {type(field)}" - fixed_fields[name] = (type_, field) - - # Fixing the fields shouldn't change the order - assert list(fixed_fields.keys()) == list(fields.keys()) - - # Default prompt when no instructions are provided - if instructions is None: - sig = Signature(signature, "") # Simple way to parse input/output fields - instructions = _default_instructions(sig) - - signature = create_model("Signature", __base__=Signature, **fixed_fields) - signature.__doc__ = instructions - return signature - - def equals(cls, other): + def equals(cls, other) -> bool: """Compare the JSON schema of two Pydantic models.""" if not isinstance(other, type) or not issubclass(other, BaseModel): return False @@ -226,30 +150,44 @@ def equals(cls, other): return True def __repr__(cls): - """ - Outputs something on the form: + """Output a representation of the signature. + + Uses the form: Signature(question, context -> answer question: str = InputField(desc="..."), context: List[str] = InputField(desc="..."), answer: int = OutputField(desc="..."), - ) + ). """ field_reprs = [] for name, field in cls.fields.items(): field_reprs.append(f"{name} = Field({field})") field_repr = "\n ".join(field_reprs) - return ( - f"Signature({cls.signature}\n" - f" instructions={repr(cls.instructions)}\n" - f" {field_repr}\n)" - ) + return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)" class Signature(BaseModel, metaclass=SignatureMeta): + """A signature for a predictor. + + You typically subclass it, like this: + class MySignature(Signature): + input: str = InputField(desc="...") + output: int = OutputField(desc="...") + + You can call Signature("input1, input2 -> output1, output2") to create a new signature type. + You can also include instructions, Signature("input -> output", "This is a test"). + But it's generally better to use the make_signature function. + + If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), + or a signature, you can use the ensure_signature function. + + For compatibility with the legacy dsp format, you can use the signature_to_template function. + """ + pass -def ensure_signature(signature): +def ensure_signature(signature: str | Type[Signature]) -> Signature: if signature is None: return None if isinstance(signature, str): @@ -257,19 +195,97 @@ def ensure_signature(signature): return signature -def infer_prefix(attribute_name: str) -> str: - """Infers a prefix from an attribute name.""" +def make_signature( + signature: Union[str, Dict[str, Tuple[type, FieldInfo]]], + instructions: str = None, + signature_name: str = "StringSignature", +) -> Type[Signature]: + """Create a new Signature type with the given fields and instructions. + + Note: + Even though we're calling a type, we're not making an instance of the type. + In general, instances of Signature types are not allowed to be made. The call + syntax is provided for convenience. + + Args: + signature: The signature format, specified as "input1, input2 -> output1, output2". + instructions: An optional prompt for the signature. + signature_name: An optional name for the new signature type. + """ + fields = _parse_signature(signature) if isinstance(signature, str) else signature + + # Validate the fields, this is important because we sometimes forget the + # slightly unintuitive syntax with tuples of (type, Field) + fixed_fields = {} + for name, type_field in fields.items(): + if not isinstance(name, str): + raise ValueError(f"Field names must be strings, not {type(name)}") + if isinstance(type_field, FieldInfo): + type_ = type_field.annotation + field = type_field + else: + if not isinstance(type_field, tuple): + raise ValueError(f"Field values must be tuples, not {type(type_field)}") + type_, field = type_field + # It might be better to be explicit about the type, but it currently would break + # program of thought and teleprompters, so we just silently default to string. + if type_ is None: + type_ = str + if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type): + raise ValueError(f"Field types must be types, not {type(type_)}") + if not isinstance(field, FieldInfo): + raise ValueError(f"Field values must be Field instances, not {type(field)}") + fixed_fields[name] = (type_, field) + + # Fixing the fields shouldn't change the order + assert list(fixed_fields.keys()) == list(fields.keys()) # noqa: S101 + + # Default prompt when no instructions are provided + if instructions is None: + sig = Signature(signature, "") # Simple way to parse input/output fields + instructions = _default_instructions(sig) + + return create_model( + signature_name, + __base__=Signature, + __doc__=instructions, + **fixed_fields, + ) + +def _parse_signature(signature: str) -> Tuple[Type, Field]: + pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$" + if not re.match(pattern, signature): + raise ValueError(f"Invalid signature format: '{signature}'") + + fields = {} + inputs_str, outputs_str = map(str.strip, signature.split("->")) + inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] + outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] + for name in inputs: + fields[name] = (str, InputField()) + for name in outputs: + fields[name] = (str, OutputField()) + + return fields + + +def infer_prefix(attribute_name: str) -> str: + """Infer a prefix from an attribute name.""" # Convert camelCase to snake_case, but handle sequences of capital letters properly s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name) intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) # Insert underscores around numbers to ensure spaces in the final output with_underscores_around_numbers = re.sub( - r"([a-zA-Z])(\d)", r"\1_\2", intermediate_name, + r"([a-zA-Z])(\d)", + r"\1_\2", + intermediate_name, ) with_underscores_around_numbers = re.sub( - r"(\d)([a-zA-Z])", r"\1_\2", with_underscores_around_numbers, + r"(\d)([a-zA-Z])", + r"\1_\2", + with_underscores_around_numbers, ) # Convert snake_case to 'Proper Title Case', but ensure acronyms are uppercased diff --git a/pyproject.toml b/pyproject.toml index bdb4bc75d9..7828de0cf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -232,28 +232,14 @@ ignore = [ "ANN003", # utf-8 encoding skip "UP009", - # First argument of a method should be named `self` - "N805", - # 1 blank line required between summary line and description - "D205", # Missing return type annotation for special method `__init__` "ANN204", - # Avoid using the generic variable name `df` for DataFrames - "PD901", - # Unnecessary assignment to `df` before `return` statement - "RET504", - # commented code - "ERA001", # Star-arg unpacking after a keyword argument is strongly discouraged "B026", # Missing type annotation for function argument `self` "ANN001", # Dynamically typed expressions (typing.Any) are disallowed in `wrapper` "ANN401", - # Unnecessary `elif` after `return` statement - "RET505", - # Within an `except` clause, raise exceptions with `raise - "B904", # We don't need docstrings for every method "ANN202", "D107", diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index b093258540..26fdfc7aa0 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -39,12 +39,8 @@ class TestSignature(Signature): input = InputField(prefix="Modified:") output = OutputField() - assert ( - TestSignature.input_fields["input"].json_schema_extra["prefix"] == "Modified:" - ) - assert ( - TestSignature.output_fields["output"].json_schema_extra["prefix"] == "Output:" - ) + assert TestSignature.input_fields["input"].json_schema_extra["prefix"] == "Modified:" + assert TestSignature.output_fields["output"].json_schema_extra["prefix"] == "Output:" def test_signature_parsing(): @@ -69,10 +65,7 @@ def test_with_updated_field(): assert signature1 is not signature2, "The type should be immutable" for key in signature1.fields.keys(): if key != "input1": - assert ( - signature1.fields[key].json_schema_extra - == signature2.fields[key].json_schema_extra - ) + assert signature1.fields[key].json_schema_extra == signature2.fields[key].json_schema_extra assert signature1.instructions == signature2.instructions @@ -97,18 +90,14 @@ def test_signature_instructions_none(): def test_signature_from_dict(): - signature = Signature( - {"input1": InputField(), "input2": InputField(), "output": OutputField()} - ) + signature = Signature({"input1": InputField(), "input2": InputField(), "output": OutputField()}) for k in ["input1", "input2", "output"]: assert k in signature.fields assert signature.fields[k].annotation == str def test_signature_from_dict(): - signature = Signature( - {"input1": InputField(), "input2": InputField(), "output": OutputField()} - ) + signature = Signature({"input1": InputField(), "input2": InputField(), "output": OutputField()}) assert "input1" in signature.input_fields assert "input2" in signature.input_fields assert "output" in signature.output_fields @@ -164,3 +153,22 @@ def test_infer_prefix(): assert infer_prefix("URLAddress") == "URL Address" assert infer_prefix("isHTTPSecure") == "Is HTTP Secure" assert infer_prefix("isHTTPSSecure123") == "Is HTTPS Secure 123" + + +def test_insantiating(): + sig = Signature("input -> output") + assert issubclass(sig, Signature) + assert sig.__name__ == "StringSignature" + value = sig(input="test", output="test") + assert isinstance(value, sig) + + +def test_insantiating2(): + class SubSignature(Signature): + input = InputField() + output = OutputField() + + assert issubclass(SubSignature, Signature) + assert SubSignature.__name__ == "SubSignature" + value = SubSignature(input="test", output="test") + assert isinstance(value, SubSignature) From 6af5366e9064918e784f68ed736626e08c53a317 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sat, 2 Mar 2024 19:23:47 -0800 Subject: [PATCH 2/8] Avoid shadowing by lint --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7828de0cf3..db7fc3fe4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,6 +218,8 @@ select = [ "ERA", # pandas-vet "PD", + # avoid shadowing + "PLW", ] ignore = [ "D100", From 9a1c189171b774f590e17254758cfe166294cb89 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sun, 3 Mar 2024 12:30:25 -0800 Subject: [PATCH 3/8] Support for n=... and type-strings --- dspy/functional/functional.py | 61 ++++++++++++++++------------- dspy/primitives/prediction.py | 20 +++++----- dspy/signatures/signature.py | 53 ++++++++++++++++++++++--- tests/functional/test_functional.py | 51 ++++++++++++++++++------ 4 files changed, 130 insertions(+), 55 deletions(-) diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 96c218ebbd..041ee7abc5 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -1,3 +1,4 @@ +from collections import defaultdict import inspect import os import openai @@ -7,6 +8,7 @@ from typing import Annotated, List, Tuple # noqa: UP035 from dsp.templates import passages2text import json +from dspy.primitives.prediction import Prediction from dspy.signatures.signature import ensure_signature @@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802 class TypedPredictor(dspy.Module): def __init__(self, signature): super().__init__() - self.signature = signature + self.signature = ensure_signature(signature) self.predictor = dspy.Predict(signature) def copy(self) -> "TypedPredictor": @@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature: name, desc=field.json_schema_extra.get("desc", "") + ( - ". Respond with a single JSON object. JSON Schema: " - + json.dumps(type_.model_json_schema()) + ". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema()) ), format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)), parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)), @@ -152,28 +153,33 @@ def forward(self, **kwargs) -> dspy.Prediction: for try_i in range(MAX_RETRIES): result = self.predictor(**modified_kwargs, new_signature=signature) errors = {} - parsed_results = {} + parsed_results = defaultdict(list) # Parse the outputs for name, field in signature.output_fields.items(): - try: - value = getattr(result, name) - parser = field.json_schema_extra.get("parser", lambda x: x) - parsed_results[name] = parser(value) - except (pydantic.ValidationError, ValueError) as e: - errors[name] = _format_error(e) - # If we can, we add an example to the error message - current_desc = field.json_schema_extra.get("desc", "") - i = current_desc.find("JSON Schema: ") - if i == -1: - continue # Only add examples to JSON objects - suffix, current_desc = current_desc[i:], current_desc[:i] - prefix = "You MUST use this format: " - if try_i + 1 < MAX_RETRIES \ - and prefix not in current_desc \ - and (example := self._make_example(field.annotation)): - signature = signature.with_updated_fields( - name, desc=current_desc + "\n" + prefix + example + "\n" + suffix, - ) + for i, completion in enumerate(result.completions): + try: + value = completion[name] + parser = field.json_schema_extra.get("parser", lambda x: x) + completion[name] = parser(value) + parsed_results[name].append(parser(value)) + except (pydantic.ValidationError, ValueError) as e: + errors[name] = _format_error(e) + # If we can, we add an example to the error message + current_desc = field.json_schema_extra.get("desc", "") + i = current_desc.find("JSON Schema: ") + if i == -1: + continue # Only add examples to JSON objects + suffix, current_desc = current_desc[i:], current_desc[:i] + prefix = "You MUST use this format: " + if ( + try_i + 1 < MAX_RETRIES + and prefix not in current_desc + and (example := self._make_example(field.annotation)) + ): + signature = signature.with_updated_fields( + name, + desc=current_desc + "\n" + prefix + example + "\n" + suffix, + ) if errors: # Add new fields for each error for name, error in errors.items(): @@ -187,11 +193,12 @@ def forward(self, **kwargs) -> dspy.Prediction: ) else: # If there are no errors, we return the parsed results - for name, value in parsed_results.items(): - setattr(result, name, value) - return result + # for name, value in parsed_results.items(): + # setattr(result, name, value) + return Prediction.from_completions(parsed_results) raise ValueError( - "Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors, + "Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", + errors, ) diff --git a/dspy/primitives/prediction.py b/dspy/primitives/prediction.py index df653c1c4a..d77ee15a9d 100644 --- a/dspy/primitives/prediction.py +++ b/dspy/primitives/prediction.py @@ -4,12 +4,12 @@ class Prediction(Example): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + del self._demos del self._input_keys self._completions = None - + @classmethod def from_completions(cls, list_or_dict, signature=None): obj = cls() @@ -17,16 +17,16 @@ def from_completions(cls, list_or_dict, signature=None): obj._store = {k: v[0] for k, v in obj._completions.items()} return obj - + def __repr__(self): - store_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._store.items()) + store_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._store.items()) if self._completions is None or len(self._completions) == 1: return f"Prediction(\n {store_repr}\n)" - + num_completions = len(self._completions) return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)" - + def __str__(self): return self.__repr__() @@ -62,15 +62,15 @@ def __getitem__(self, key): if isinstance(key, int): if key < 0 or key >= len(self): raise IndexError("Index out of range") - + return Prediction(**{k: v[key] for k, v in self._completions.items()}) - + return self._completions[key] def __getattr__(self, name): if name in self._completions: return self._completions[name] - + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def __len__(self): @@ -82,7 +82,7 @@ def __contains__(self, key): return key in self._completions def __repr__(self): - items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items()) + items_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._completions.items()) return f"Completions(\n {items_repr}\n)" def __str__(self): diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index a335f219f0..95455c3704 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,9 +1,10 @@ +import ast from copy import deepcopy import typing import dsp from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo -from typing import Type, Union, Dict, Tuple # noqa: UP035 +from typing import Any, Type, Union, Dict, Tuple # noqa: UP035 import re from dspy.signatures.field import InputField, OutputField, new_to_old_field @@ -254,7 +255,7 @@ def make_signature( def _parse_signature(signature: str) -> Tuple[Type, Field]: - pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$" + pattern = r"^\s*[\w\s,:]+\s*->\s*[\w\s,:]+\s*$" if not re.match(pattern, signature): raise ValueError(f"Invalid signature format: '{signature}'") @@ -262,14 +263,54 @@ def _parse_signature(signature: str) -> Tuple[Type, Field]: inputs_str, outputs_str = map(str.strip, signature.split("->")) inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] - for name in inputs: - fields[name] = (str, InputField()) - for name in outputs: - fields[name] = (str, OutputField()) + for name_type in inputs: + name, type_ = _parse_named_type_node(name_type) + fields[name] = (type_, InputField()) + for name_type in outputs: + name, type_ = _parse_named_type_node(name_type) + fields[name] = (type_, OutputField()) return fields +def _parse_named_type_node(node, names=None) -> Any: + parts = node.split(":") + if len(parts) == 1: + return parts[0], str + name, type_str = parts + type_ = _parse_type_node(ast.parse(type_str), names) + return name, type_ + + +def _parse_type_node(node, names=None) -> Any: + """Recursively parse an AST node representing a type annotation. + + using structural pattern matching introduced in Python 3.10. + """ + if names is None: + names = {} + match node: + case ast.Module(body=body): + if len(body) != 1: + raise ValueError(f"Code is not syntactically valid: {node}") + return _parse_type_node(body[0], names) + case ast.Expr(value=value): + return _parse_type_node(value, names) + case ast.Name(id=id): + if id in names: + return names[id] + for type_ in [int, str, float, bool, list, tuple, dict]: + if type_.__name__ == id: + return type_ + case ast.Subscript(value=value, slice=slice): + base_type = _parse_type_node(value, names) + arg_type = _parse_type_node(slice, names) + return base_type[arg_type] + case ast.Tuple(elts=elts): + return tuple(_parse_type_node(elt, names) for elt in elts) + raise ValueError(f"Code is not syntactically valid: {node}") + + def infer_prefix(attribute_name: str) -> str: """Infer a prefix from an attribute name.""" # Convert camelCase to snake_case, but handle sequences of capital letters properly diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index f515f93266..029519bcfd 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -35,9 +35,7 @@ def hard_questions(topics: List[str]) -> List[str]: pass expected = ["What is the speed of light?", "What is the speed of sound?"] - lm = DummyLM( - ['{"value": ["What is the speed of light?", "What is the speed of sound?"]}'] - ) + lm = DummyLM(['{"value": ["What is the speed of light?", "What is the speed of sound?"]}']) dspy.settings.configure(lm=lm) question = hard_questions(topics=["Physics", "Music"]) @@ -88,9 +86,7 @@ def test_simple_class(): class Answer(pydantic.BaseModel): value: float certainty: float - comments: List[str] = pydantic.Field( - description="At least two comments about the answer" - ) + comments: List[str] = pydantic.Field(description="At least two comments about the answer") class QA(FunctionalModule): @predictor @@ -229,9 +225,7 @@ def simple_metric(example, prediction, trace=None): lm = DummyLM(["blue", "Ring-ding-ding-ding-dingeringeding!"], follow_examples=True) dspy.settings.configure(lm=lm, trace=[]) - bootstrap = BootstrapFewShot( - metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 - ) + bootstrap = BootstrapFewShot(metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1) compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset) lm.inspect_history(n=2) @@ -295,7 +289,7 @@ def flight_information(email: str) -> TravelInformation: # Example with a bad origin code. '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', # Example to help the model understand - '{...}', + "{...}", # Fixed '{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}', ] @@ -344,9 +338,9 @@ def flight_information(email: str) -> TravelInformation: [ # First origin is wrong, then destination, then all is good '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', - '{...}', # Example to help the model understand + "{...}", # Example to help the model understand '{"origin": "JFK", "destination": "LA0", "date": "2022-12-25"}', - '{...}', # Example to help the model understand + "{...}", # Example to help the model understand '{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}', ] ) @@ -447,3 +441,36 @@ def test(input: Annotated[str, Field(description="description")]) -> Annotated[f output = test(input="input") assert output == 0.5 + + +def test_multiple_outputs(): + lm = DummyLM([str(i) for i in range(100)]) + dspy.settings.configure(lm=lm) + + test = TypedPredictor("input -> output") + output = test(input="input", config=dict(n=3)).completions.output + assert output == ["0", "1", "2"] + + +def test_multiple_outputs_int(): + lm = DummyLM([str(i) for i in range(100)]) + dspy.settings.configure(lm=lm) + + class TestSignature(dspy.Signature): + input: int = dspy.InputField() + output: int = dspy.OutputField() + + test = TypedPredictor(TestSignature) + + output = test(input=8, config=dict(n=3)).completions.output + assert output == [0, 1, 2] + + +def test_parse_type_string(): + lm = DummyLM([str(i) for i in range(100)]) + dspy.settings.configure(lm=lm) + + test = TypedPredictor("input:int -> output:int") + + output = test(input=8, config=dict(n=3)).completions.output + assert output == [0, 1, 2] From a021bd26763a0c3efa8fffa84e8a8fcdf27554f8 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sun, 3 Mar 2024 12:30:35 -0800 Subject: [PATCH 4/8] Linting of dummy --- dspy/utils/dummies.py | 62 ++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index a0c997145b..3a37388073 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -1,6 +1,5 @@ import random from dsp.modules import LM -from typing import List, Union, Dict import numpy as np from dsp.utils.utils import dotdict import re @@ -9,9 +8,9 @@ class DummyLM(LM): """Dummy language model for unit testing purposes.""" - def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bool = False): - """ - Initializes the dummy language model. + def __init__(self, answers: list[str] | dict[str, str], follow_examples: bool = False): + """Initializes the dummy language model. + Parameters: - answers: A list of strings or a dictionary with string keys and values. - follow_examples: If True, and the prompt contains an example exactly equal to the prompt, @@ -24,7 +23,7 @@ def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bo self.answers = answers self.follow_examples = follow_examples - def basic_request(self, prompt, n=1, **kwargs): + def basic_request(self, prompt, n=1, **kwargs) -> dict[str, list[dict[str, str]]]: """Generates a dummy response based on the prompt.""" dummy_response = {"choices": []} for _ in range(n): @@ -55,12 +54,14 @@ def basic_request(self, prompt, n=1, **kwargs): answer = "No more responses" # Mimic the structure of a real language model response. - dummy_response["choices"].append({ - "text": answer, - "finish_reason": "simulated completion", - }) - - RED, GREEN, RESET = '\033[91m', '\033[92m', '\033[0m' + dummy_response["choices"].append( + { + "text": answer, + "finish_reason": "simulated completion", + }, + ) + + RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m" print("=== DummyLM ===") print(prompt, end="") print(f"{RED}{answer}{RESET}") @@ -77,67 +78,68 @@ def basic_request(self, prompt, n=1, **kwargs): return dummy_response - def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): + def __call__(self, prompt, _only_completed=True, _return_sorted=False, **kwargs): """Retrieves dummy completions.""" response = self.basic_request(prompt, **kwargs) choices = response["choices"] # Filter choices and return text completions. - completions = [choice["text"] for choice in choices] - - return completions + return [choice["text"] for choice in choices] - def get_convo(self, index): - """Get the prompt + anwer from the ith message""" - return self.history[index]['prompt'] \ - + " " \ - + self.history[index]['response']['choices'][0]['text'] + def get_convo(self, index) -> str: + """Get the prompt + anwer from the ith message.""" + return self.history[index]["prompt"] + " " + self.history[index]["response"]["choices"][0]["text"] -def dummy_rm(passages=()): +def dummy_rm(passages=()) -> callable: if not passages: - def inner(query:str, *, k:int, **kwargs): + + def inner(query: str, *, k: int, **kwargs): assert False, "No passages defined" + return inner max_length = max(map(len, passages)) + 100 vectorizer = DummyVectorizer(max_length) passage_vecs = vectorizer(passages) - def inner(query:str, *, k:int, **kwargs): + + def inner(query: str, *, k: int, **kwargs): assert k <= len(passages) query_vec = vectorizer([query])[0] scores = passage_vecs @ query_vec largest_idx = (-scores).argsort()[:k] - #return dspy.Prediction(passages=[passages[i] for i in largest_idx]) + # return dspy.Prediction(passages=[passages[i] for i in largest_idx]) return [dotdict(dict(long_text=passages[i])) for i in largest_idx] + return inner class DummyVectorizer: - """Simple vectorizer based on n-grams""" + """Simple vectorizer based on n-grams.""" + def __init__(self, max_length=100, n_gram=2): self.max_length = max_length self.n_gram = n_gram self.P = 10**9 + 7 # A large prime number random.seed(123) self.coeffs = [random.randrange(1, self.P) for _ in range(n_gram)] - + def _hash(self, gram): - """Hashes a string using a polynomial hash function""" + """Hashes a string using a polynomial hash function.""" h = 1 for coeff, c in zip(self.coeffs, gram): h = h * coeff + ord(c) h %= self.P return h % self.max_length - def __call__(self, texts: List[str]) -> np.ndarray: + def __call__(self, texts: list[str]) -> np.ndarray: vecs = [] for text in texts: - grams = [text[i:i+self.n_gram] for i in range(len(text) - self.n_gram + 1)] + grams = [text[i : i + self.n_gram] for i in range(len(text) - self.n_gram + 1)] vec = [0] * self.max_length for gram in grams: vec[self._hash(gram)] += 1 vecs.append(vec) - + vecs = np.array(vecs, dtype=np.float32) vecs -= np.mean(vecs, axis=1, keepdims=True) vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-10 # Added epsilon to avoid division by zero From f223579a7e124c1e8b5ad27910f6a05190cd6433 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sun, 3 Mar 2024 13:08:30 -0800 Subject: [PATCH 5/8] Rewrote type parsing to be 3.9 compatible --- dspy/signatures/signature.py | 49 +++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 95455c3704..6493a8ce19 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -285,29 +285,38 @@ def _parse_named_type_node(node, names=None) -> Any: def _parse_type_node(node, names=None) -> Any: """Recursively parse an AST node representing a type annotation. - using structural pattern matching introduced in Python 3.10. + without using structural pattern matching introduced in Python 3.10. """ if names is None: names = {} - match node: - case ast.Module(body=body): - if len(body) != 1: - raise ValueError(f"Code is not syntactically valid: {node}") - return _parse_type_node(body[0], names) - case ast.Expr(value=value): - return _parse_type_node(value, names) - case ast.Name(id=id): - if id in names: - return names[id] - for type_ in [int, str, float, bool, list, tuple, dict]: - if type_.__name__ == id: - return type_ - case ast.Subscript(value=value, slice=slice): - base_type = _parse_type_node(value, names) - arg_type = _parse_type_node(slice, names) - return base_type[arg_type] - case ast.Tuple(elts=elts): - return tuple(_parse_type_node(elt, names) for elt in elts) + + if isinstance(node, ast.Module): + body = node.body + if len(body) != 1: + raise ValueError(f"Code is not syntactically valid: {node}") + return _parse_type_node(body[0], names) + + if isinstance(node, ast.Expr): + value = node.value + return _parse_type_node(value, names) + + if isinstance(node, ast.Name): + id_ = node.id + if id_ in names: + return names[id_] + for type_ in [int, str, float, bool, list, tuple, dict]: + if type_.__name__ == id_: + return type_ + + elif isinstance(node, ast.Subscript): + base_type = _parse_type_node(node.value, names) + arg_type = _parse_type_node(node.slice, names) + return base_type[arg_type] + + elif isinstance(node, ast.Tuple): + elts = node.elts + return tuple(_parse_type_node(elt, names) for elt in elts) + raise ValueError(f"Code is not syntactically valid: {node}") From dab44cd1fa5b18c8e46f595e889cebfa8ccaf461 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sun, 3 Mar 2024 13:14:39 -0800 Subject: [PATCH 6/8] Python 3.9 compatability --- dspy/signatures/signature.py | 2 +- dspy/utils/dummies.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 6493a8ce19..8c1f161642 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -188,7 +188,7 @@ class MySignature(Signature): pass -def ensure_signature(signature: str | Type[Signature]) -> Signature: +def ensure_signature(signature: Union[str, Type[Signature]]) -> Signature: if signature is None: return None if isinstance(signature, str): diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index 3a37388073..1f6390536a 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -1,4 +1,5 @@ import random +from typing import Union from dsp.modules import LM import numpy as np from dsp.utils.utils import dotdict @@ -8,7 +9,7 @@ class DummyLM(LM): """Dummy language model for unit testing purposes.""" - def __init__(self, answers: list[str] | dict[str, str], follow_examples: bool = False): + def __init__(self, answers: Union[list[str], dict[str, str]], follow_examples: bool = False): """Initializes the dummy language model. Parameters: From fe321da85da1bc290a0bcd2e6c277d2d7e68a91c Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sun, 3 Mar 2024 13:28:51 -0800 Subject: [PATCH 7/8] Validate main signature fields --- dspy/functional/functional.py | 48 ++++++++++++++++------------- tests/functional/test_functional.py | 17 ++++++++++ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 041ee7abc5..aa82288295 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -155,31 +155,37 @@ def forward(self, **kwargs) -> dspy.Prediction: errors = {} parsed_results = defaultdict(list) # Parse the outputs - for name, field in signature.output_fields.items(): - for i, completion in enumerate(result.completions): - try: + for i, completion in enumerate(result.completions): + try: + for name, field in signature.output_fields.items(): value = completion[name] parser = field.json_schema_extra.get("parser", lambda x: x) completion[name] = parser(value) parsed_results[name].append(parser(value)) - except (pydantic.ValidationError, ValueError) as e: - errors[name] = _format_error(e) - # If we can, we add an example to the error message - current_desc = field.json_schema_extra.get("desc", "") - i = current_desc.find("JSON Schema: ") - if i == -1: - continue # Only add examples to JSON objects - suffix, current_desc = current_desc[i:], current_desc[:i] - prefix = "You MUST use this format: " - if ( - try_i + 1 < MAX_RETRIES - and prefix not in current_desc - and (example := self._make_example(field.annotation)) - ): - signature = signature.with_updated_fields( - name, - desc=current_desc + "\n" + prefix + example + "\n" + suffix, - ) + # Instantiate the actual signature with the parsed values. + # This allow pydantic to validate the fields defined in the signature. + _dummy = self.signature( + **kwargs, + **{key: value[i] for key, value in parsed_results.items()}, + ) + except (pydantic.ValidationError, ValueError) as e: + errors[name] = _format_error(e) + # If we can, we add an example to the error message + current_desc = field.json_schema_extra.get("desc", "") + i = current_desc.find("JSON Schema: ") + if i == -1: + continue # Only add examples to JSON objects + suffix, current_desc = current_desc[i:], current_desc[:i] + prefix = "You MUST use this format: " + if ( + try_i + 1 < MAX_RETRIES + and prefix not in current_desc + and (example := self._make_example(field.annotation)) + ): + signature = signature.with_updated_fields( + name, + desc=current_desc + "\n" + prefix + example + "\n" + suffix, + ) if errors: # Add new fields for each error for name, error in errors.items(): diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 029519bcfd..e012324f74 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -474,3 +474,20 @@ def test_parse_type_string(): output = test(input=8, config=dict(n=3)).completions.output assert output == [0, 1, 2] + + +def test_fields_on_base_signature(): + class SimpleOutput(dspy.Signature): + output: float = dspy.OutputField(gt=0, lt=1) + + lm = DummyLM( + [ + "2.1", # Bad output + "0.5", # Good output + ] + ) + dspy.settings.configure(lm=lm) + + predictor = TypedPredictor(SimpleOutput) + + assert predictor().output == 0.5 From 16be7a1bc84decdb417e2f0e4d337656db34b995 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Sun, 3 Mar 2024 14:12:52 -0800 Subject: [PATCH 8/8] Fixes to completions --- dspy/functional/functional.py | 21 +++++++-------- tests/functional/test_functional.py | 41 +++++++++++++++++++++++++++++ tests/signatures/test_signature.py | 2 ++ 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index aa82288295..cb7b3669aa 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -10,7 +10,7 @@ import json from dspy.primitives.prediction import Prediction -from dspy.signatures.signature import ensure_signature +from dspy.signatures.signature import ensure_signature, make_signature MAX_RETRIES = 3 @@ -83,7 +83,7 @@ def copy(self) -> "TypedPredictor": def _make_example(type_) -> str: # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called. json_object = dspy.Predict( - dspy.Signature( + make_signature( "json_schema -> json_object", "Make a very succinct json object that validates with the following schema", ), @@ -153,21 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction: for try_i in range(MAX_RETRIES): result = self.predictor(**modified_kwargs, new_signature=signature) errors = {} - parsed_results = defaultdict(list) + parsed_results = [] # Parse the outputs for i, completion in enumerate(result.completions): try: + parsed = {} for name, field in signature.output_fields.items(): value = completion[name] parser = field.json_schema_extra.get("parser", lambda x: x) completion[name] = parser(value) - parsed_results[name].append(parser(value)) + parsed[name] = parser(value) # Instantiate the actual signature with the parsed values. # This allow pydantic to validate the fields defined in the signature. - _dummy = self.signature( - **kwargs, - **{key: value[i] for key, value in parsed_results.items()}, - ) + _dummy = self.signature(**kwargs, **parsed) + parsed_results.append(parsed) except (pydantic.ValidationError, ValueError) as e: errors[name] = _format_error(e) # If we can, we add an example to the error message @@ -199,9 +198,9 @@ def forward(self, **kwargs) -> dspy.Prediction: ) else: # If there are no errors, we return the parsed results - # for name, value in parsed_results.items(): - # setattr(result, name, value) - return Prediction.from_completions(parsed_results) + return Prediction.from_completions( + {key: [r[key] for r in parsed_results] for key in signature.output_fields} + ) raise ValueError( "Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors, diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index e012324f74..cabbc7cd09 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -11,6 +11,7 @@ from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor, functional from dspy.primitives.example import Example from dspy.teleprompt.bootstrap import BootstrapFewShot +from dspy.teleprompt.vanilla import LabeledFewShot from dspy.utils.dummies import DummyLM @@ -491,3 +492,43 @@ class SimpleOutput(dspy.Signature): predictor = TypedPredictor(SimpleOutput) assert predictor().output == 0.5 + + +def test_synthetic_data_gen(): + class SyntheticFact(BaseModel): + fact: str = Field(..., description="a statement") + varacity: bool = Field(..., description="is the statement true or false") + + class ExampleSignature(dspy.Signature): + """Generate an example of a synthetic fact.""" + + fact: SyntheticFact = dspy.OutputField() + + lm = DummyLM( + [ + '{"fact": "The sky is blue", "varacity": true}', + '{"fact": "The sky is green", "varacity": false}', + '{"fact": "The sky is red", "varacity": true}', + '{"fact": "The earth is flat", "varacity": false}', + '{"fact": "The earth is round", "varacity": true}', + '{"fact": "The earth is a cube", "varacity": false}', + ] + ) + dspy.settings.configure(lm=lm) + + generator = TypedPredictor(ExampleSignature) + examples = generator(config=dict(n=3)) + for ex in examples.completions.fact: + assert isinstance(ex, SyntheticFact) + assert examples.completions.fact[0] == SyntheticFact(fact="The sky is blue", varacity=True) + + # If you have examples and want more + existing_examples = [ + dspy.Example(fact="The sky is blue", varacity=True), + dspy.Example(fact="The sky is green", varacity=False), + ] + trained = LabeledFewShot().compile(student=generator, trainset=existing_examples) + + augmented_examples = trained(config=dict(n=3)) + for ex in augmented_examples.completions.fact: + assert isinstance(ex, SyntheticFact) diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index 26fdfc7aa0..554d9b1274 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -82,6 +82,8 @@ def test_instructions_signature(): def test_signature_instructions(): sig1 = Signature("input1 -> output1", instructions="This is a test") assert sig1.instructions == "This is a test" + sig2 = Signature("input1 -> output1", "This is a test") + assert sig2.instructions == "This is a test" def test_signature_instructions_none():