diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 96c218ebbd..cb7b3669aa 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -1,3 +1,4 @@ +from collections import defaultdict import inspect import os import openai @@ -7,8 +8,9 @@ from typing import Annotated, List, Tuple # noqa: UP035 from dsp.templates import passages2text import json +from dspy.primitives.prediction import Prediction -from dspy.signatures.signature import ensure_signature +from dspy.signatures.signature import ensure_signature, make_signature MAX_RETRIES = 3 @@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802 class TypedPredictor(dspy.Module): def __init__(self, signature): super().__init__() - self.signature = signature + self.signature = ensure_signature(signature) self.predictor = dspy.Predict(signature) def copy(self) -> "TypedPredictor": @@ -81,7 +83,7 @@ def copy(self) -> "TypedPredictor": def _make_example(type_) -> str: # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called. json_object = dspy.Predict( - dspy.Signature( + make_signature( "json_schema -> json_object", "Make a very succinct json object that validates with the following schema", ), @@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature: name, desc=field.json_schema_extra.get("desc", "") + ( - ". Respond with a single JSON object. JSON Schema: " - + json.dumps(type_.model_json_schema()) + ". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema()) ), format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)), parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)), @@ -152,13 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction: for try_i in range(MAX_RETRIES): result = self.predictor(**modified_kwargs, new_signature=signature) errors = {} - parsed_results = {} + parsed_results = [] # Parse the outputs - for name, field in signature.output_fields.items(): + for i, completion in enumerate(result.completions): try: - value = getattr(result, name) - parser = field.json_schema_extra.get("parser", lambda x: x) - parsed_results[name] = parser(value) + parsed = {} + for name, field in signature.output_fields.items(): + value = completion[name] + parser = field.json_schema_extra.get("parser", lambda x: x) + completion[name] = parser(value) + parsed[name] = parser(value) + # Instantiate the actual signature with the parsed values. + # This allow pydantic to validate the fields defined in the signature. + _dummy = self.signature(**kwargs, **parsed) + parsed_results.append(parsed) except (pydantic.ValidationError, ValueError) as e: errors[name] = _format_error(e) # If we can, we add an example to the error message @@ -168,11 +176,14 @@ def forward(self, **kwargs) -> dspy.Prediction: continue # Only add examples to JSON objects suffix, current_desc = current_desc[i:], current_desc[:i] prefix = "You MUST use this format: " - if try_i + 1 < MAX_RETRIES \ - and prefix not in current_desc \ - and (example := self._make_example(field.annotation)): + if ( + try_i + 1 < MAX_RETRIES + and prefix not in current_desc + and (example := self._make_example(field.annotation)) + ): signature = signature.with_updated_fields( - name, desc=current_desc + "\n" + prefix + example + "\n" + suffix, + name, + desc=current_desc + "\n" + prefix + example + "\n" + suffix, ) if errors: # Add new fields for each error @@ -187,11 +198,12 @@ def forward(self, **kwargs) -> dspy.Prediction: ) else: # If there are no errors, we return the parsed results - for name, value in parsed_results.items(): - setattr(result, name, value) - return result + return Prediction.from_completions( + {key: [r[key] for r in parsed_results] for key in signature.output_fields} + ) raise ValueError( - "Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors, + "Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", + errors, ) diff --git a/dspy/primitives/prediction.py b/dspy/primitives/prediction.py index df653c1c4a..d77ee15a9d 100644 --- a/dspy/primitives/prediction.py +++ b/dspy/primitives/prediction.py @@ -4,12 +4,12 @@ class Prediction(Example): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + del self._demos del self._input_keys self._completions = None - + @classmethod def from_completions(cls, list_or_dict, signature=None): obj = cls() @@ -17,16 +17,16 @@ def from_completions(cls, list_or_dict, signature=None): obj._store = {k: v[0] for k, v in obj._completions.items()} return obj - + def __repr__(self): - store_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._store.items()) + store_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._store.items()) if self._completions is None or len(self._completions) == 1: return f"Prediction(\n {store_repr}\n)" - + num_completions = len(self._completions) return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)" - + def __str__(self): return self.__repr__() @@ -62,15 +62,15 @@ def __getitem__(self, key): if isinstance(key, int): if key < 0 or key >= len(self): raise IndexError("Index out of range") - + return Prediction(**{k: v[key] for k, v in self._completions.items()}) - + return self._completions[key] def __getattr__(self, name): if name in self._completions: return self._completions[name] - + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def __len__(self): @@ -82,7 +82,7 @@ def __contains__(self, key): return key in self._completions def __repr__(self): - items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items()) + items_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._completions.items()) return f"Completions(\n {items_repr}\n)" def __str__(self): diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 7cd2f170cb..8c1f161642 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,39 +1,45 @@ +import ast from copy import deepcopy import typing import dsp from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo -from typing import Type, Union, Dict, Tuple +from typing import Any, Type, Union, Dict, Tuple # noqa: UP035 import re from dspy.signatures.field import InputField, OutputField, new_to_old_field -def signature_to_template(signature): - """Convert from new to legacy format""" +def signature_to_template(signature) -> dsp.Template: + """Convert from new to legacy format.""" return dsp.Template( signature.instructions, **{name: new_to_old_field(field) for name, field in signature.fields.items()}, ) -def _default_instructions(cls): - inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields.keys()]) - outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields.keys()]) +def _default_instructions(cls) -> str: + inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields]) + outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields]) return f"Given the fields {inputs_}, produce the fields {outputs_}." class SignatureMeta(type(BaseModel)): - def __new__(mcs, name, bases, namespace, **kwargs): + def __call__(cls, *args, **kwargs): # noqa: ANN002 + if cls is Signature: + return make_signature(*args, **kwargs) + return super().__call__(*args, **kwargs) + + def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804 # Set `str` as the default type for all fields raw_annotations = namespace.get("__annotations__", {}) - for name, field in namespace.items(): + for name, _field in namespace.items(): if not name.startswith("__") and name not in raw_annotations: raw_annotations[name] = str namespace["__annotations__"] = raw_annotations # Let Pydantic do its thing - cls = super().__new__(mcs, name, bases, namespace, **kwargs) + cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs) if cls.__doc__ is None: cls.__doc__ = _default_instructions(cls) @@ -69,17 +75,20 @@ def signature(cls) -> str: def instructions(cls) -> str: return getattr(cls, "__doc__", "") - def with_instructions(cls, instructions: str): + def with_instructions(cls, instructions: str) -> Type["Signature"]: return Signature(cls.fields, instructions) @property - def fields(cls): + def fields(cls) -> dict[str, FieldInfo]: # Make sure to give input fields before output fields return {**cls.input_fields, **cls.output_fields} - def with_updated_fields(cls, name, type_=None, **kwargs): - """Returns a new Signature type with the field, name, updated - with fields[name].json_schema_extra[key] = value.""" + def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]: + """Update the field, name, in a new Signature type. + + Returns a new Signature type with the field, name, updated + with fields[name].json_schema_extra[key] = value. + """ fields_copy = deepcopy(cls.fields) fields_copy[name].json_schema_extra = { **fields_copy[name].json_schema_extra, @@ -90,27 +99,23 @@ def with_updated_fields(cls, name, type_=None, **kwargs): return Signature(fields_copy, cls.instructions) @property - def input_fields(cls): + def input_fields(cls) -> dict[str, FieldInfo]: return cls._get_fields_with_type("input") @property - def output_fields(cls): + def output_fields(cls) -> dict[str, FieldInfo]: return cls._get_fields_with_type("output") - def _get_fields_with_type(cls, field_type): - return { - k: v - for k, v in cls.model_fields.items() - if v.json_schema_extra["__dspy_field_type"] == field_type - } + def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]: + return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type} - def prepend(cls, name, field, type_=None): + def prepend(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(0, name, field, type_) - def append(cls, name, field, type_=None): + def append(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(-1, name, field, type_) - def insert(cls, index: int, name: str, field, type_: Type = None): + def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signature"]: # It's posisble to set the type as annotation=type in pydantic.Field(...) # But this may be annoying for users, so we allow them to pass the type if type_ is None: @@ -122,11 +127,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None): output_fields = list(cls.output_fields.items()) # Choose the list to insert into based on the field type - lst = ( - input_fields - if field.json_schema_extra["__dspy_field_type"] == "input" - else output_fields - ) + lst = input_fields if field.json_schema_extra["__dspy_field_type"] == "input" else output_fields # We support negative insert indices if index < 0: index += len(lst) + 1 @@ -137,83 +138,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None): new_fields = dict(input_fields + output_fields) return Signature(new_fields, cls.instructions) - def _parse_signature(cls, signature: str) -> Tuple[Type, Field]: - pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$" - if not re.match(pattern, signature): - raise ValueError(f"Invalid signature format: '{signature}'") - - fields = {} - inputs_str, outputs_str = map(str.strip, signature.split("->")) - inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] - outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] - for name in inputs: - fields[name] = (str, InputField()) - for name in outputs: - fields[name] = (str, OutputField()) - - return fields - - def __call__( - cls, - signature: Union[str, Dict[str, Tuple[type, FieldInfo]]], - instructions: str = None, - ): - """ - Creates a new Signature type with the given fields and instructions. - Note: - Even though we're calling a type, we're not making an instance of the type. - In general we don't allow instances of Signature types to be made. The call - syntax is only for your convenience. - Parameters: - signature: Format: "input1, input2 -> output1, output2" - instructions: Optional prompt for the signature. - """ - - if isinstance(signature, str): - fields = cls._parse_signature(signature) - else: - fields = signature - - # Validate the fields, this is important because we sometimes forget the - # slightly unintuitive syntax with tuples of (type, Field) - fixed_fields = {} - for name, type_field in fields.items(): - assert isinstance( - name, str, - ), f"Field names must be strings, not {type(name)}" - if isinstance(type_field, FieldInfo): - type_ = type_field.annotation - field = type_field - else: - assert isinstance( - type_field, tuple, - ), f"Field values must be tuples, not {type(type_field)}" - type_, field = type_field - # It might be better to be explicit about the type, but it currently would break - # program of thought and teleprompters, so we just silently default to string. - if type_ is None: - type_ = str - assert isinstance(type_, type) or isinstance( - typing.get_origin(type_), type, - ), f"Field types must be types, not {type(type_)}" - assert isinstance( - field, FieldInfo, - ), f"Field values must be Field instances, not {type(field)}" - fixed_fields[name] = (type_, field) - - # Fixing the fields shouldn't change the order - assert list(fixed_fields.keys()) == list(fields.keys()) - - # Default prompt when no instructions are provided - if instructions is None: - sig = Signature(signature, "") # Simple way to parse input/output fields - instructions = _default_instructions(sig) - - signature = create_model("Signature", __base__=Signature, **fixed_fields) - signature.__doc__ = instructions - return signature - - def equals(cls, other): + def equals(cls, other) -> bool: """Compare the JSON schema of two Pydantic models.""" if not isinstance(other, type) or not issubclass(other, BaseModel): return False @@ -226,30 +151,44 @@ def equals(cls, other): return True def __repr__(cls): - """ - Outputs something on the form: + """Output a representation of the signature. + + Uses the form: Signature(question, context -> answer question: str = InputField(desc="..."), context: List[str] = InputField(desc="..."), answer: int = OutputField(desc="..."), - ) + ). """ field_reprs = [] for name, field in cls.fields.items(): field_reprs.append(f"{name} = Field({field})") field_repr = "\n ".join(field_reprs) - return ( - f"Signature({cls.signature}\n" - f" instructions={repr(cls.instructions)}\n" - f" {field_repr}\n)" - ) + return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)" class Signature(BaseModel, metaclass=SignatureMeta): + """A signature for a predictor. + + You typically subclass it, like this: + class MySignature(Signature): + input: str = InputField(desc="...") + output: int = OutputField(desc="...") + + You can call Signature("input1, input2 -> output1, output2") to create a new signature type. + You can also include instructions, Signature("input -> output", "This is a test"). + But it's generally better to use the make_signature function. + + If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), + or a signature, you can use the ensure_signature function. + + For compatibility with the legacy dsp format, you can use the signature_to_template function. + """ + pass -def ensure_signature(signature): +def ensure_signature(signature: Union[str, Type[Signature]]) -> Signature: if signature is None: return None if isinstance(signature, str): @@ -257,19 +196,146 @@ def ensure_signature(signature): return signature -def infer_prefix(attribute_name: str) -> str: - """Infers a prefix from an attribute name.""" +def make_signature( + signature: Union[str, Dict[str, Tuple[type, FieldInfo]]], + instructions: str = None, + signature_name: str = "StringSignature", +) -> Type[Signature]: + """Create a new Signature type with the given fields and instructions. + + Note: + Even though we're calling a type, we're not making an instance of the type. + In general, instances of Signature types are not allowed to be made. The call + syntax is provided for convenience. + + Args: + signature: The signature format, specified as "input1, input2 -> output1, output2". + instructions: An optional prompt for the signature. + signature_name: An optional name for the new signature type. + """ + fields = _parse_signature(signature) if isinstance(signature, str) else signature + + # Validate the fields, this is important because we sometimes forget the + # slightly unintuitive syntax with tuples of (type, Field) + fixed_fields = {} + for name, type_field in fields.items(): + if not isinstance(name, str): + raise ValueError(f"Field names must be strings, not {type(name)}") + if isinstance(type_field, FieldInfo): + type_ = type_field.annotation + field = type_field + else: + if not isinstance(type_field, tuple): + raise ValueError(f"Field values must be tuples, not {type(type_field)}") + type_, field = type_field + # It might be better to be explicit about the type, but it currently would break + # program of thought and teleprompters, so we just silently default to string. + if type_ is None: + type_ = str + if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type): + raise ValueError(f"Field types must be types, not {type(type_)}") + if not isinstance(field, FieldInfo): + raise ValueError(f"Field values must be Field instances, not {type(field)}") + fixed_fields[name] = (type_, field) + + # Fixing the fields shouldn't change the order + assert list(fixed_fields.keys()) == list(fields.keys()) # noqa: S101 + + # Default prompt when no instructions are provided + if instructions is None: + sig = Signature(signature, "") # Simple way to parse input/output fields + instructions = _default_instructions(sig) + + return create_model( + signature_name, + __base__=Signature, + __doc__=instructions, + **fixed_fields, + ) + + +def _parse_signature(signature: str) -> Tuple[Type, Field]: + pattern = r"^\s*[\w\s,:]+\s*->\s*[\w\s,:]+\s*$" + if not re.match(pattern, signature): + raise ValueError(f"Invalid signature format: '{signature}'") + + fields = {} + inputs_str, outputs_str = map(str.strip, signature.split("->")) + inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] + outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] + for name_type in inputs: + name, type_ = _parse_named_type_node(name_type) + fields[name] = (type_, InputField()) + for name_type in outputs: + name, type_ = _parse_named_type_node(name_type) + fields[name] = (type_, OutputField()) + + return fields + + +def _parse_named_type_node(node, names=None) -> Any: + parts = node.split(":") + if len(parts) == 1: + return parts[0], str + name, type_str = parts + type_ = _parse_type_node(ast.parse(type_str), names) + return name, type_ + + +def _parse_type_node(node, names=None) -> Any: + """Recursively parse an AST node representing a type annotation. + without using structural pattern matching introduced in Python 3.10. + """ + if names is None: + names = {} + + if isinstance(node, ast.Module): + body = node.body + if len(body) != 1: + raise ValueError(f"Code is not syntactically valid: {node}") + return _parse_type_node(body[0], names) + + if isinstance(node, ast.Expr): + value = node.value + return _parse_type_node(value, names) + + if isinstance(node, ast.Name): + id_ = node.id + if id_ in names: + return names[id_] + for type_ in [int, str, float, bool, list, tuple, dict]: + if type_.__name__ == id_: + return type_ + + elif isinstance(node, ast.Subscript): + base_type = _parse_type_node(node.value, names) + arg_type = _parse_type_node(node.slice, names) + return base_type[arg_type] + + elif isinstance(node, ast.Tuple): + elts = node.elts + return tuple(_parse_type_node(elt, names) for elt in elts) + + raise ValueError(f"Code is not syntactically valid: {node}") + + +def infer_prefix(attribute_name: str) -> str: + """Infer a prefix from an attribute name.""" # Convert camelCase to snake_case, but handle sequences of capital letters properly s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name) intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) # Insert underscores around numbers to ensure spaces in the final output with_underscores_around_numbers = re.sub( - r"([a-zA-Z])(\d)", r"\1_\2", intermediate_name, + r"([a-zA-Z])(\d)", + r"\1_\2", + intermediate_name, ) with_underscores_around_numbers = re.sub( - r"(\d)([a-zA-Z])", r"\1_\2", with_underscores_around_numbers, + r"(\d)([a-zA-Z])", + r"\1_\2", + with_underscores_around_numbers, ) # Convert snake_case to 'Proper Title Case', but ensure acronyms are uppercased diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index a0c997145b..1f6390536a 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -1,6 +1,6 @@ import random +from typing import Union from dsp.modules import LM -from typing import List, Union, Dict import numpy as np from dsp.utils.utils import dotdict import re @@ -9,9 +9,9 @@ class DummyLM(LM): """Dummy language model for unit testing purposes.""" - def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bool = False): - """ - Initializes the dummy language model. + def __init__(self, answers: Union[list[str], dict[str, str]], follow_examples: bool = False): + """Initializes the dummy language model. + Parameters: - answers: A list of strings or a dictionary with string keys and values. - follow_examples: If True, and the prompt contains an example exactly equal to the prompt, @@ -24,7 +24,7 @@ def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bo self.answers = answers self.follow_examples = follow_examples - def basic_request(self, prompt, n=1, **kwargs): + def basic_request(self, prompt, n=1, **kwargs) -> dict[str, list[dict[str, str]]]: """Generates a dummy response based on the prompt.""" dummy_response = {"choices": []} for _ in range(n): @@ -55,12 +55,14 @@ def basic_request(self, prompt, n=1, **kwargs): answer = "No more responses" # Mimic the structure of a real language model response. - dummy_response["choices"].append({ - "text": answer, - "finish_reason": "simulated completion", - }) - - RED, GREEN, RESET = '\033[91m', '\033[92m', '\033[0m' + dummy_response["choices"].append( + { + "text": answer, + "finish_reason": "simulated completion", + }, + ) + + RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m" print("=== DummyLM ===") print(prompt, end="") print(f"{RED}{answer}{RESET}") @@ -77,67 +79,68 @@ def basic_request(self, prompt, n=1, **kwargs): return dummy_response - def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): + def __call__(self, prompt, _only_completed=True, _return_sorted=False, **kwargs): """Retrieves dummy completions.""" response = self.basic_request(prompt, **kwargs) choices = response["choices"] # Filter choices and return text completions. - completions = [choice["text"] for choice in choices] - - return completions + return [choice["text"] for choice in choices] - def get_convo(self, index): - """Get the prompt + anwer from the ith message""" - return self.history[index]['prompt'] \ - + " " \ - + self.history[index]['response']['choices'][0]['text'] + def get_convo(self, index) -> str: + """Get the prompt + anwer from the ith message.""" + return self.history[index]["prompt"] + " " + self.history[index]["response"]["choices"][0]["text"] -def dummy_rm(passages=()): +def dummy_rm(passages=()) -> callable: if not passages: - def inner(query:str, *, k:int, **kwargs): + + def inner(query: str, *, k: int, **kwargs): assert False, "No passages defined" + return inner max_length = max(map(len, passages)) + 100 vectorizer = DummyVectorizer(max_length) passage_vecs = vectorizer(passages) - def inner(query:str, *, k:int, **kwargs): + + def inner(query: str, *, k: int, **kwargs): assert k <= len(passages) query_vec = vectorizer([query])[0] scores = passage_vecs @ query_vec largest_idx = (-scores).argsort()[:k] - #return dspy.Prediction(passages=[passages[i] for i in largest_idx]) + # return dspy.Prediction(passages=[passages[i] for i in largest_idx]) return [dotdict(dict(long_text=passages[i])) for i in largest_idx] + return inner class DummyVectorizer: - """Simple vectorizer based on n-grams""" + """Simple vectorizer based on n-grams.""" + def __init__(self, max_length=100, n_gram=2): self.max_length = max_length self.n_gram = n_gram self.P = 10**9 + 7 # A large prime number random.seed(123) self.coeffs = [random.randrange(1, self.P) for _ in range(n_gram)] - + def _hash(self, gram): - """Hashes a string using a polynomial hash function""" + """Hashes a string using a polynomial hash function.""" h = 1 for coeff, c in zip(self.coeffs, gram): h = h * coeff + ord(c) h %= self.P return h % self.max_length - def __call__(self, texts: List[str]) -> np.ndarray: + def __call__(self, texts: list[str]) -> np.ndarray: vecs = [] for text in texts: - grams = [text[i:i+self.n_gram] for i in range(len(text) - self.n_gram + 1)] + grams = [text[i : i + self.n_gram] for i in range(len(text) - self.n_gram + 1)] vec = [0] * self.max_length for gram in grams: vec[self._hash(gram)] += 1 vecs.append(vec) - + vecs = np.array(vecs, dtype=np.float32) vecs -= np.mean(vecs, axis=1, keepdims=True) vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-10 # Added epsilon to avoid division by zero diff --git a/pyproject.toml b/pyproject.toml index c1ad1e8210..1c0f759684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,6 +218,8 @@ select = [ "ERA", # pandas-vet "PD", + # avoid shadowing + "PLW", ] ignore = [ "D100", @@ -232,28 +234,14 @@ ignore = [ "ANN003", # utf-8 encoding skip "UP009", - # First argument of a method should be named `self` - "N805", - # 1 blank line required between summary line and description - "D205", # Missing return type annotation for special method `__init__` "ANN204", - # Avoid using the generic variable name `df` for DataFrames - "PD901", - # Unnecessary assignment to `df` before `return` statement - "RET504", - # commented code - "ERA001", # Star-arg unpacking after a keyword argument is strongly discouraged "B026", # Missing type annotation for function argument `self` "ANN001", # Dynamically typed expressions (typing.Any) are disallowed in `wrapper` "ANN401", - # Unnecessary `elif` after `return` statement - "RET505", - # Within an `except` clause, raise exceptions with `raise - "B904", # We don't need docstrings for every method "ANN202", "D107", diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index f515f93266..cabbc7cd09 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -11,6 +11,7 @@ from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor, functional from dspy.primitives.example import Example from dspy.teleprompt.bootstrap import BootstrapFewShot +from dspy.teleprompt.vanilla import LabeledFewShot from dspy.utils.dummies import DummyLM @@ -35,9 +36,7 @@ def hard_questions(topics: List[str]) -> List[str]: pass expected = ["What is the speed of light?", "What is the speed of sound?"] - lm = DummyLM( - ['{"value": ["What is the speed of light?", "What is the speed of sound?"]}'] - ) + lm = DummyLM(['{"value": ["What is the speed of light?", "What is the speed of sound?"]}']) dspy.settings.configure(lm=lm) question = hard_questions(topics=["Physics", "Music"]) @@ -88,9 +87,7 @@ def test_simple_class(): class Answer(pydantic.BaseModel): value: float certainty: float - comments: List[str] = pydantic.Field( - description="At least two comments about the answer" - ) + comments: List[str] = pydantic.Field(description="At least two comments about the answer") class QA(FunctionalModule): @predictor @@ -229,9 +226,7 @@ def simple_metric(example, prediction, trace=None): lm = DummyLM(["blue", "Ring-ding-ding-ding-dingeringeding!"], follow_examples=True) dspy.settings.configure(lm=lm, trace=[]) - bootstrap = BootstrapFewShot( - metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 - ) + bootstrap = BootstrapFewShot(metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1) compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset) lm.inspect_history(n=2) @@ -295,7 +290,7 @@ def flight_information(email: str) -> TravelInformation: # Example with a bad origin code. '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', # Example to help the model understand - '{...}', + "{...}", # Fixed '{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}', ] @@ -344,9 +339,9 @@ def flight_information(email: str) -> TravelInformation: [ # First origin is wrong, then destination, then all is good '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', - '{...}', # Example to help the model understand + "{...}", # Example to help the model understand '{"origin": "JFK", "destination": "LA0", "date": "2022-12-25"}', - '{...}', # Example to help the model understand + "{...}", # Example to help the model understand '{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}', ] ) @@ -447,3 +442,93 @@ def test(input: Annotated[str, Field(description="description")]) -> Annotated[f output = test(input="input") assert output == 0.5 + + +def test_multiple_outputs(): + lm = DummyLM([str(i) for i in range(100)]) + dspy.settings.configure(lm=lm) + + test = TypedPredictor("input -> output") + output = test(input="input", config=dict(n=3)).completions.output + assert output == ["0", "1", "2"] + + +def test_multiple_outputs_int(): + lm = DummyLM([str(i) for i in range(100)]) + dspy.settings.configure(lm=lm) + + class TestSignature(dspy.Signature): + input: int = dspy.InputField() + output: int = dspy.OutputField() + + test = TypedPredictor(TestSignature) + + output = test(input=8, config=dict(n=3)).completions.output + assert output == [0, 1, 2] + + +def test_parse_type_string(): + lm = DummyLM([str(i) for i in range(100)]) + dspy.settings.configure(lm=lm) + + test = TypedPredictor("input:int -> output:int") + + output = test(input=8, config=dict(n=3)).completions.output + assert output == [0, 1, 2] + + +def test_fields_on_base_signature(): + class SimpleOutput(dspy.Signature): + output: float = dspy.OutputField(gt=0, lt=1) + + lm = DummyLM( + [ + "2.1", # Bad output + "0.5", # Good output + ] + ) + dspy.settings.configure(lm=lm) + + predictor = TypedPredictor(SimpleOutput) + + assert predictor().output == 0.5 + + +def test_synthetic_data_gen(): + class SyntheticFact(BaseModel): + fact: str = Field(..., description="a statement") + varacity: bool = Field(..., description="is the statement true or false") + + class ExampleSignature(dspy.Signature): + """Generate an example of a synthetic fact.""" + + fact: SyntheticFact = dspy.OutputField() + + lm = DummyLM( + [ + '{"fact": "The sky is blue", "varacity": true}', + '{"fact": "The sky is green", "varacity": false}', + '{"fact": "The sky is red", "varacity": true}', + '{"fact": "The earth is flat", "varacity": false}', + '{"fact": "The earth is round", "varacity": true}', + '{"fact": "The earth is a cube", "varacity": false}', + ] + ) + dspy.settings.configure(lm=lm) + + generator = TypedPredictor(ExampleSignature) + examples = generator(config=dict(n=3)) + for ex in examples.completions.fact: + assert isinstance(ex, SyntheticFact) + assert examples.completions.fact[0] == SyntheticFact(fact="The sky is blue", varacity=True) + + # If you have examples and want more + existing_examples = [ + dspy.Example(fact="The sky is blue", varacity=True), + dspy.Example(fact="The sky is green", varacity=False), + ] + trained = LabeledFewShot().compile(student=generator, trainset=existing_examples) + + augmented_examples = trained(config=dict(n=3)) + for ex in augmented_examples.completions.fact: + assert isinstance(ex, SyntheticFact) diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index b093258540..554d9b1274 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -39,12 +39,8 @@ class TestSignature(Signature): input = InputField(prefix="Modified:") output = OutputField() - assert ( - TestSignature.input_fields["input"].json_schema_extra["prefix"] == "Modified:" - ) - assert ( - TestSignature.output_fields["output"].json_schema_extra["prefix"] == "Output:" - ) + assert TestSignature.input_fields["input"].json_schema_extra["prefix"] == "Modified:" + assert TestSignature.output_fields["output"].json_schema_extra["prefix"] == "Output:" def test_signature_parsing(): @@ -69,10 +65,7 @@ def test_with_updated_field(): assert signature1 is not signature2, "The type should be immutable" for key in signature1.fields.keys(): if key != "input1": - assert ( - signature1.fields[key].json_schema_extra - == signature2.fields[key].json_schema_extra - ) + assert signature1.fields[key].json_schema_extra == signature2.fields[key].json_schema_extra assert signature1.instructions == signature2.instructions @@ -89,6 +82,8 @@ def test_instructions_signature(): def test_signature_instructions(): sig1 = Signature("input1 -> output1", instructions="This is a test") assert sig1.instructions == "This is a test" + sig2 = Signature("input1 -> output1", "This is a test") + assert sig2.instructions == "This is a test" def test_signature_instructions_none(): @@ -97,18 +92,14 @@ def test_signature_instructions_none(): def test_signature_from_dict(): - signature = Signature( - {"input1": InputField(), "input2": InputField(), "output": OutputField()} - ) + signature = Signature({"input1": InputField(), "input2": InputField(), "output": OutputField()}) for k in ["input1", "input2", "output"]: assert k in signature.fields assert signature.fields[k].annotation == str def test_signature_from_dict(): - signature = Signature( - {"input1": InputField(), "input2": InputField(), "output": OutputField()} - ) + signature = Signature({"input1": InputField(), "input2": InputField(), "output": OutputField()}) assert "input1" in signature.input_fields assert "input2" in signature.input_fields assert "output" in signature.output_fields @@ -164,3 +155,22 @@ def test_infer_prefix(): assert infer_prefix("URLAddress") == "URL Address" assert infer_prefix("isHTTPSecure") == "Is HTTP Secure" assert infer_prefix("isHTTPSSecure123") == "Is HTTPS Secure 123" + + +def test_insantiating(): + sig = Signature("input -> output") + assert issubclass(sig, Signature) + assert sig.__name__ == "StringSignature" + value = sig(input="test", output="test") + assert isinstance(value, sig) + + +def test_insantiating2(): + class SubSignature(Signature): + input = InputField() + output = OutputField() + + assert issubclass(SubSignature, Signature) + assert SubSignature.__name__ == "SubSignature" + value = SubSignature(input="test", output="test") + assert isinstance(value, SubSignature)