diff --git a/dsp/templates/template_v2.py b/dsp/templates/template_v2.py index 236fd1d64e..0df17185bb 100644 --- a/dsp/templates/template_v2.py +++ b/dsp/templates/template_v2.py @@ -72,6 +72,8 @@ def query(self, example: Example, is_demo: bool = False) -> str: """Retrieves the input variables from the example and formats them into a query string.""" result: list[str] = [] + # If not a demo, find the last field that doesn't have a value set in `example` and set it to "" + # This creates the "Output:" prefix at the end of the prompt. if not is_demo: has_value = [ field.input_variable in example @@ -80,40 +82,40 @@ def query(self, example: Example, is_demo: bool = False) -> str: for field in self.fields ] - for i in range(1, len(has_value)): - if has_value[i - 1] and not any(has_value[i:]): - example[self.fields[i].input_variable] = "" - break + # If there are no inputs, set the first field to "" + if not any(has_value): + example[self.fields[0].input_variable] = "" + # Otherwise find the first field without a value. + else: + for i in range(1, len(has_value)): + if has_value[i - 1] and not any(has_value[i:]): + example[self.fields[i].input_variable] = "" + break for field in self.fields: - if ( - field.input_variable in example - and example[field.input_variable] is not None - ): + if field.input_variable in example and example[field.input_variable] is not None: if field.input_variable in self.format_handlers: format_handler = self.format_handlers[field.input_variable] else: + def format_handler(x): assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}" return " ".join(x.split()) formatted_value = format_handler(example[field.input_variable]) - separator = '\n' if field.separator == ' ' and '\n' in formatted_value else field.separator + separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator result.append( f"{field.name}{separator}{formatted_value}", ) - if self._has_augmented_guidelines() and (example.get('augmented', False)): + if self._has_augmented_guidelines() and (example.get("augmented", False)): return "\n\n".join([r for r in result if r]) return "\n".join([r for r in result if r]) def guidelines(self, show_guidelines=True) -> str: """Returns the task guidelines as described in the lm prompt""" - if (not show_guidelines) or ( - hasattr(dsp.settings, "show_guidelines") - and not dsp.settings.show_guidelines - ): + if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines): return "" result = "Follow the following format.\n\n" @@ -128,11 +130,13 @@ def guidelines(self, show_guidelines=True) -> str: def _has_augmented_guidelines(self): return len(self.fields) > 3 or any( - ("\n" in field.separator) or ('\n' in field.description) for field in self.fields + ("\n" in field.separator) or ("\n" in field.description) for field in self.fields ) def extract( - self, example: Union[Example, dict[str, Any]], raw_pred: str, + self, + example: Union[Example, dict[str, Any]], + raw_pred: str, ) -> Example: """Extracts the answer from the LM raw prediction using the template structure @@ -149,10 +153,7 @@ def extract( idx = 0 while idx < len(self.fields): - if ( - self.fields[idx].input_variable not in example - or example[self.fields[idx].input_variable] is None - ): + if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None: break idx += 1 @@ -166,8 +167,8 @@ def extract( if offset >= 0: if dspy.settings.release >= 20231003: - example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip('---').strip() - raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip('---').strip() + example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip() + raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip() else: example[self.fields[idx].output_variable] = raw_pred[:offset].strip() raw_pred = raw_pred[offset + len(next_field_name) :].strip() @@ -175,7 +176,7 @@ def extract( idx += 1 else: if dspy.settings.release >= 20231003: - example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip() + example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip() else: example[self.fields[idx].output_variable] = raw_pred.strip() @@ -187,7 +188,7 @@ def extract( assert idx == len(self.fields) - 1, (idx, len(self.fields)) if dspy.settings.release >= 20231003: - example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip() + example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip() else: example[self.fields[idx].output_variable] = raw_pred.strip() @@ -198,7 +199,7 @@ def extract( def __call__(self, example, show_guidelines=True) -> str: example = dsp.Example(example) - if hasattr(dsp.settings, 'query_only') and dsp.settings.query_only: + if hasattr(dsp.settings, "query_only") and dsp.settings.query_only: return self.query(example) # The training data should not contain the output variable @@ -209,29 +210,20 @@ def __call__(self, example, show_guidelines=True) -> str: self.query(demo, is_demo=True) for demo in example.demos if ( - (not demo.get('augmented', False)) + (not demo.get("augmented", False)) and ( # validate that the training example has the same primitive input var as the template - self.fields[-1].input_variable in demo - and demo[self.fields[-1].input_variable] is not None + self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None ) ) ] - ademos = [ - self.query(demo, is_demo=True) - for demo in example.demos - if demo.get('augmented', False) - ] + ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)] # Move the rdemos to ademos if rdemo has all the fields filled in rdemos_ = [] new_ademos = [] for rdemo in rdemos: - if all( - (field.name in rdemo) - for field in self.fields - if field.input_variable in example - ): + if all((field.name in rdemo) for field in self.fields if field.input_variable in example): import dspy if dspy.settings.release >= 20230928: @@ -244,7 +236,6 @@ def __call__(self, example, show_guidelines=True) -> str: ademos = new_ademos + ademos rdemos = rdemos_ - long_query = self._has_augmented_guidelines() if long_query: @@ -253,10 +244,10 @@ def __call__(self, example, show_guidelines=True) -> str: query = self.query(example) # if it has more lines than fields - if len(query.split('\n')) > len(self.fields): + if len(query.split("\n")) > len(self.fields): long_query = True - if not example.get('augmented', False): + if not example.get("augmented", False): example["augmented"] = True query = self.query(example) diff --git a/dspy/__init__.py b/dspy/__init__.py index b61ea8a017..55dd6b6ffc 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -1,5 +1,3 @@ -# from .evaluation import * -# FIXME: import dsp from dsp.modules.hf_client import ChatModuleClient, HFClientSGLang, HFClientVLLM, HFServerTGI @@ -8,6 +6,9 @@ from .retrieve import * from .signatures import * +# Functional must be imported after primitives, predict and signatures +from .functional import * # isort: skip + settings = dsp.settings AzureOpenAI = dsp.AzureOpenAI diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 97bbdb8d89..44f9f5a5d5 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -1,16 +1,20 @@ import inspect import json +import os import typing -from typing import Annotated, List, Tuple +from typing import Annotated, List, Tuple # noqa: UP035 +import openai import pydantic +import ujson import dspy -from dsp.templates.utils import passages2text +from dsp.templates import passages2text from dspy.primitives.prediction import Prediction from dspy.signatures.signature import ensure_signature, make_signature -MAX_RETRIES = 3 +# Some improvement ideas: +# - Increase the temperature on error def predictor(func) -> dspy.Module: @@ -52,7 +56,7 @@ def __init__(self): self.__dict__[name] = attr.copy() -def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802 +def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802 """Just like TypedPredictor, but adds a ChainOfThought OutputField.""" signature = ensure_signature(signature) output_keys = ", ".join(signature.output_fields.keys()) @@ -64,17 +68,19 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802 desc="${produce the " + output_keys + "}. We ...", ), ), + max_retries=max_retries, ) class TypedPredictor(dspy.Module): - def __init__(self, signature): + def __init__(self, signature, max_retries=3): super().__init__() self.signature = ensure_signature(signature) self.predictor = dspy.Predict(signature) + self.max_retries = max_retries def copy(self) -> "TypedPredictor": - return TypedPredictor(self.signature) + return TypedPredictor(self.signature, self.max_retries) @staticmethod def _make_example(type_) -> str: @@ -98,9 +104,7 @@ def _make_example(type_) -> str: # library like https://pypi.org/project/polyfactory/ that's made exactly to do this. def _prepare_signature(self) -> dspy.Signature: - """Add formats and parsers to the signature fields, based on the type - annotations of the fields. - """ + """Add formats and parsers to the signature fields, based on the type annotations of the fields.""" signature = self.signature for name, field in self.signature.fields.items(): is_output = field.json_schema_extra["__dspy_field_type"] == "output" @@ -114,20 +118,36 @@ def _prepare_signature(self) -> dspy.Signature: format=lambda x: x if isinstance(x, str) else str(x), parser=type_, ) + elif False: + # TODO: I don't like forcing the model to write "value" in the output. + if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)): + type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel) + to_json = lambda x, type_=type_: type_(value=x).model_dump_json()[9:-1] # {"value":"123"} + from_json = lambda x, type_=type_: type_.model_validate_json('{"value":' + x + "}").value + schema = json.dumps(type_.model_json_schema()["properties"]["value"]) + else: + to_json = lambda x: x.model_dump_json() + from_json = lambda x, type_=type_: type_.model_validate_json(x) + schema = json.dumps(type_.model_json_schema()) else: # Anything else we wrap in a pydantic object - to_json = lambda x: x.model_dump_json() - from_json = lambda x, type_=type_: type_.model_validate_json(x) - if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)): + if not ( + inspect.isclass(type_) + and typing.get_origin(type_) not in (list, tuple) # To support Python 3.9 + and issubclass(type_, pydantic.BaseModel) + ): type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel) to_json = lambda x, type_=type_: type_(value=x).model_dump_json() from_json = lambda x, type_=type_: type_.model_validate_json(x).value + schema = json.dumps(type_.model_json_schema()) + else: + to_json = lambda x: x.model_dump_json() + from_json = lambda x, type_=type_: type_.model_validate_json(x) + schema = json.dumps(type_.model_json_schema()) signature = signature.with_updated_fields( name, desc=field.json_schema_extra.get("desc", "") - + ( - ". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema()) - ), + + (". Respond with a single JSON object. JSON Schema: " + schema), format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)), parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)), type_=type_, @@ -136,6 +156,15 @@ def _prepare_signature(self) -> dspy.Signature: format_ = lambda x: x if isinstance(x, str) else str(x) if type_ in (List[str], list[str], Tuple[str], tuple[str]): format_ = passages2text + # Special formatting for lists of known types. Maybe the output fields sohuld have this too? + elif typing.get_origin(type_) in (List, list, Tuple, tuple): + (inner_type,) = typing.get_args(type_) + if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel): + format_ = ( + lambda x: x if isinstance(x, str) else "[" + ",".join(i.model_dump_json() for i in x) + "]" + ) + else: + format_ = lambda x: x if isinstance(x, str) else json.dumps(x) elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): format_ = lambda x: x if isinstance(x, str) else x.model_dump_json() signature = signature.with_updated_fields(name, format=format_) @@ -147,41 +176,46 @@ def forward(self, **kwargs) -> dspy.Prediction: # We have to re-prepare the signature on every forward call, because the base # signature might have been modified by an optimizer or something like that. signature = self._prepare_signature() - for try_i in range(MAX_RETRIES): + for try_i in range(self.max_retries): result = self.predictor(**modified_kwargs, new_signature=signature) errors = {} parsed_results = [] # Parse the outputs - for i, completion in enumerate(result.completions): - try: - parsed = {} - for name, field in signature.output_fields.items(): + for completion in result.completions: + parsed = {} + for name, field in signature.output_fields.items(): + try: value = completion[name] parser = field.json_schema_extra.get("parser", lambda x: x) - completion[name] = parser(value) parsed[name] = parser(value) - # Instantiate the actual signature with the parsed values. - # This allow pydantic to validate the fields defined in the signature. + except (pydantic.ValidationError, ValueError) as e: + errors[name] = _format_error(e) + # If we can, we add an example to the error message + current_desc = field.json_schema_extra.get("desc", "") + i = current_desc.find("JSON Schema: ") + if i == -1: + continue # Only add examples to JSON objects + suffix, current_desc = current_desc[i:], current_desc[:i] + prefix = "You MUST use this format: " + if ( + try_i + 1 < self.max_retries + and prefix not in current_desc + and (example := self._make_example(field.annotation)) + ): + signature = signature.with_updated_fields( + name, + desc=current_desc + "\n" + prefix + example + "\n" + suffix, + ) + # No reason trying to parse the general signature, or run more completions, if we already have errors + if errors: + break + # Instantiate the actual signature with the parsed values. + # This allow pydantic to validate the fields defined in the signature. + try: _dummy = self.signature(**kwargs, **parsed) parsed_results.append(parsed) - except (pydantic.ValidationError, ValueError) as e: - errors[name] = _format_error(e) - # If we can, we add an example to the error message - current_desc = field.json_schema_extra.get("desc", "") - i = current_desc.find("JSON Schema: ") - if i == -1: - continue # Only add examples to JSON objects - suffix, current_desc = current_desc[i:], current_desc[:i] - prefix = "You MUST use this format: " - if ( - try_i + 1 < MAX_RETRIES - and prefix not in current_desc - and (example := self._make_example(field.annotation)) - ): - signature = signature.with_updated_fields( - name, - desc=current_desc + "\n" + prefix + example + "\n" + suffix, - ) + except pydantic.ValidationError as e: + errors["general"] = _format_error(e) if errors: # Add new fields for each error for name, error in errors.items(): @@ -208,7 +242,7 @@ def _format_error(error: Exception): if isinstance(error, pydantic.ValidationError): errors = [] for e in error.errors(): - fields = ", ".join(e["loc"]) + fields = ", ".join(map(str, e["loc"])) errors.append(f"{e['msg']}: {fields} (error type: {e['type']})") return "; ".join(errors) else: @@ -254,7 +288,7 @@ def _unwrap_json(output): output = output[7:-3].strip() if not output.startswith("{") or not output.endswith("}"): raise ValueError("json output should start and end with { and }") - return output + return ujson.dumps(ujson.loads(output)) # ujson is a bit more robust than the standard json ################################################################################ diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index fe6beb4483..37fff10c49 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -27,7 +27,7 @@ def dump_state(self): state["signature_instructions"] = self.signature.instructions *_, last_key = self.signature.fields.keys() - state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra['prefix'] + state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra["prefix"] return state @@ -72,7 +72,8 @@ def forward(self, **kwargs): # print(f"#> Setting temperature to 0.7 since n={num_generations} and prior temperature={temperature}.") # All of the other kwargs are presumed to fit a prefix of the signature. - + # That is, they are input variables for the bottom most generation, so + # we place them inside the input - x - together with the demos. x = dsp.Example(demos=demos, **kwargs) if new_signature is not None: @@ -103,7 +104,8 @@ def forward(self, **kwargs): for field in template.fields: if field.output_variable not in kwargs.keys(): completions[-1][field.output_variable] = getattr( - c, field.output_variable, + c, + field.output_variable, ) pred = Prediction.from_completions(completions, signature=signature) diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index bee80338f5..90908e5c3a 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -1,4 +1,5 @@ import copy +from collections.abc import Generator import ujson @@ -9,7 +10,7 @@ def __init__(self): def named_parameters(self): """ - Unlike PyTorch, handles (non-recursive) lists of parameters too. + Unlike PyTorch, handles (non-recursive) lists of parameters too. """ from dspy.predict.parameter import Parameter @@ -28,10 +29,10 @@ def add_parameter(param_name, param_value): elif isinstance(value, BaseModule): # When a sub-module is pre-compiled, keep it frozen. - if not getattr(value, '_compiled', False): + if not getattr(value, "_compiled", False): for sub_name, param in value.named_parameters(): add_parameter(f"{name}.{sub_name}", param) - + elif isinstance(value, (list, tuple)): for idx, item in enumerate(value): add_parameter(f"{name}[{idx}]", item) @@ -42,6 +43,12 @@ def add_parameter(param_name, param_value): return named_parameters + def named_sub_modules(self, root_name="base") -> Generator[tuple[str, "BaseModule"], None, None]: + yield root_name, self + for name, value in self.__dict__.items(): + if isinstance(value, BaseModule): + yield from value.named_sub_modules(root_name=f"{root_name}.{name}") + def parameters(self): return [param for _, param in self.named_parameters()] @@ -50,23 +57,23 @@ def deepcopy(self): def reset_copy(self): obj = copy.deepcopy(self) - + for param in obj.parameters(): param.reset() - + return obj - + def dump_state(self): return {name: param.dump_state() for name, param in self.named_parameters()} - + def load_state(self, state): for name, param in self.named_parameters(): param.load_state(state[name]) - + def save(self, path): with open(path, "w") as f: f.write(ujson.dumps(self.dump_state(), indent=2)) - + def load(self, path): with open(path) as f: self.load_state(ujson.loads(f.read())) diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index aa499d9085..45d51ab357 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -1,4 +1,3 @@ - import re from dspy.primitives.assertions import * @@ -17,7 +16,6 @@ class ProgramMeta(type): class Module(BaseModule, metaclass=ProgramMeta): - def _base_init(self): self._compiled = False @@ -30,12 +28,7 @@ def __call__(self, *args, **kwargs): def named_predictors(self): from dspy.predict.predict import Predict - named_parameters = self.named_parameters() - return [ - (name, param) - for name, param in named_parameters - if isinstance(param, Predict) - ] + return [(name, param) for name, param in self.named_parameters() if isinstance(param, Predict)] def predictors(self): return [param for _, param in self.named_predictors()] @@ -53,7 +46,7 @@ def map_named_predictors(self, func): for name, predictor in self.named_predictors(): set_attribute_by_name(self, name, func(predictor)) return self - + def activate_assertions(self, handler=backtrack_handler, **handler_args): """ Activates assertions for the module. diff --git a/dspy/signatures/field.py b/dspy/signatures/field.py index 7822c625c0..4e32714778 100644 --- a/dspy/signatures/field.py +++ b/dspy/signatures/field.py @@ -1,5 +1,10 @@ import pydantic +# The following arguments can be used in DSPy InputField and OutputField in addition +# to the standard pydantic.Field arguments. We just hope pydanitc doesn't add these, +# as it would give a name clash. +DSPY_FIELD_ARG_NAMES = ["desc", "prefix", "format", "parser", "__dspy_field_type"] + def move_kwargs(**kwargs): # Pydantic doesn't allow arbitrary arguments to be given to fields, @@ -10,7 +15,7 @@ def move_kwargs(**kwargs): pydantic_kwargs = {} json_schema_extra = {} for k, v in kwargs.items(): - if k in ["desc", "prefix", "format", "parser", "__dspy_field_type"]: + if k in DSPY_FIELD_ARG_NAMES: json_schema_extra[k] = v else: pydantic_kwargs[k] = v @@ -27,11 +32,7 @@ def OutputField(**kwargs): def new_to_old_field(field): - return ( - OldInputField - if field.json_schema_extra["__dspy_field_type"] == "input" - else OldOutputField - )( + return (OldInputField if field.json_schema_extra["__dspy_field_type"] == "input" else OldOutputField)( prefix=field.json_schema_extra["prefix"], desc=field.json_schema_extra["desc"], format=field.json_schema_extra.get("format"), diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 22bc6475d7..314d6c712b 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,5 +1,6 @@ import ast import re +import types import typing from copy import deepcopy from typing import Any, Dict, Tuple, Type, Union # noqa: UP035 @@ -42,6 +43,17 @@ def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804 # Let Pydantic do its thing cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs) + # If we don't have instructions, it might be because we are a derived generic type. + # In that case, we should inherit the instructions from the base class. + if cls.__doc__ is None: + for base in bases: + if isinstance(base, SignatureMeta): + doc = getattr(base, "__doc__", "") + if doc != "": + cls.__doc__ = doc + + # The more likely case is that the user has just not given us a type. + # In that case, we should default to the input/output format. if cls.__doc__ is None: cls.__doc__ = _default_instructions(cls) @@ -63,7 +75,7 @@ def _validate_fields(cls): field_type = extra.get("__dspy_field_type") if field_type not in ["input", "output"]: raise TypeError( - f"Field '{name}' in '{cls.__name__}' must be declared with InputField or OutputField.", + f"Field '{name}' in '{cls.__name__}' must be declared with InputField or OutputField. {field.json_schema_extra=}", ) @property @@ -168,24 +180,27 @@ def __repr__(cls): return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)" +# A signature for a predictor. +# +# You typically subclass it, like this: +# class MySignature(Signature): +# input: str = InputField(desc="...") # noqa: ERA001 +# output: int = OutputField(desc="...") # noqa: ERA001 +# +# You can call Signature("input1, input2 -> output1, output2") to create a new signature type. +# You can also include instructions, Signature("input -> output", "This is a test"). +# But it's generally better to use the make_signature function. +# +# If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), +# or a signature, you can use the ensure_signature function. +# +# For compatibility with the legacy dsp format, you can use the signature_to_template function. +# class Signature(BaseModel, metaclass=SignatureMeta): - """A signature for a predictor. - - You typically subclass it, like this: - class MySignature(Signature): - input: str = InputField(desc="...") - output: int = OutputField(desc="...") - - You can call Signature("input1, input2 -> output1, output2") to create a new signature type. - You can also include instructions, Signature("input -> output", "This is a test"). - But it's generally better to use the make_signature function. - - If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"), - or a signature, you can use the ensure_signature function. - - For compatibility with the legacy dsp format, you can use the signature_to_template function. - """ + "" # noqa: D419 + # Note: Don't put a docstring here, as it will become the default instructions + # for any signature that doesn't define it's own instructions. pass @@ -233,7 +248,8 @@ def make_signature( # program of thought and teleprompters, so we just silently default to string. if type_ is None: type_ = str - if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type): + # if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type): + if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias)): raise ValueError(f"Field types must be types, not {type(type_)}") if not isinstance(field, FieldInfo): raise ValueError(f"Field values must be Field instances, not {type(field)}") diff --git a/dspy/teleprompt/signature_opt_typed.py b/dspy/teleprompt/signature_opt_typed.py new file mode 100644 index 0000000000..1921f8cd1d --- /dev/null +++ b/dspy/teleprompt/signature_opt_typed.py @@ -0,0 +1,251 @@ +import textwrap +from typing import Generic, Literal, TypeVar + +import pydantic + +import dspy +from dspy import BaseModel +from dspy.functional.functional import TypedChainOfThought, TypedPredictor +from dspy.signatures import Signature +from dspy.signatures.field import InputField, OutputField + +# TODO: +# - Parallelize the generation of new signatures when we have multiple predictors +# - Consider generating multiple new signatures at once, which we can test in parallel +# - Consider using the prompt optimizer to optimize the prompt optimizer :O + + +def make_info(signature: type[Signature]) -> BaseModel: + """Creates a SignatureInfo pydantic type, that describes the Signature. + + Returns an instnce of this type, with the instructions and field descriptions of the input type. + """ + # First, create the SignatureInfo type + fields = { + "instructions": (str, pydantic.Field(description="The instructions for the task")), + } + for name in signature.fields: + fields[name + "_prefix"] = (str, pydantic.Field(description=f"The prefix for {name}")) + fields[name + "_desc"] = (str, pydantic.Field(description=f"The description for {name}")) + SignatureInfo = pydantic.create_model( # noqa: N806 + f"SignatureInfo[{signature.__name__}]", + **fields, + ) + + # Add a method to convert the SignatureInfo back into a Signature + def to_signature(info): + new_signature = signature.with_instructions(info.instructions) + for name in signature.fields: + new_signature = new_signature.with_updated_fields( + name, + prefix=getattr(info, name + "_prefix"), + desc=getattr(info, name + "_desc"), + ) + return new_signature + + SignatureInfo.to_signature = to_signature + + # Finally, make an instance of the SignatureInfo type with the signature's + # default instructions and field descriptions + values = {"instructions": signature.instructions} + for name, field in signature.fields.items(): + values[name + "_prefix"] = field.json_schema_extra["prefix"] + values[name + "_desc"] = field.json_schema_extra["desc"] + return SignatureInfo(**values) + + +T = TypeVar("T", bound=BaseModel) + + +# Note: This function wouldn't be necessary if we could make the number of prompts a generic parameter of the class, +# but alas it seems like this isn't possible in Python right now. The main reason being that types and generics only +# live inside the type system, and can't be used to generate code at runtime. +def make_initial_signature(n_prompts: int) -> type[Signature]: + """Creates a GenerateInstructionInitial signature with the given number of initial prompts.""" + + class GenerateInstructionInitial(Signature, Generic[T]): + # TODO: Can we make textwrap default/automatic in all signatures? + __doc__ = textwrap.dedent("""\ + You are a creative instruction optimizer for large language models. + + I will give you a ``signature`` of fields (inputs and outputs) in English. + Your task is to propose variations of the signature that will lead a good language model. + + Be very creative and think out of the box. + You can use as long instructions as you want. + Consider using inspiration such as: + Openers: + - You are as smart as ChatGPT. + - You are highly intelligent. + - You are an expert mathematician. + - You are a professor of mathematics. + Task Descriptions: + - Be consise in your answer. + - Be as clear as possible. + - Use lots of creativity. + Closers: + - This will be fun! + - Take a deep breath and think carefully. + - I really need your help! + """) + + basic_signature: T = InputField() + proposed_signatures: list[T] = OutputField( + desc=f"A list of {n_prompts} very different variations of the basic signature", + min_items=n_prompts, + max_items=n_prompts, + ) + + return GenerateInstructionInitial + + +class GenerateSignature(dspy.Signature, Generic[T]): + __doc__ = textwrap.dedent("""\ + You are an instruction optimizer for large language models. + + I will give some task instructions I've tried, along with their corresponding validation scores. + - The instructions are arranged in order based on their scores, where higher scores indicate better quality. + - Your task is to propose a new instruction that will lead a good language model to perform the task even better. + - Be creative, and think out of the box. + - Don't repeat instructions, descriptions and prefixes that have already been attempted. + """) + + analysis: str = OutputField(desc="Consider what made the previous instructions good or bad.") + proposed_signature: T = OutputField(desc="A signature that will likely lead to a high score.") + score: float = OutputField(desc="The expected score for the new signature. Don't write anything after this number.") + + +def optimize_signature( + student, + evaluator, + n_iterations=10, + strategy: Literal["best", "last"] = "best", + sorted_order: Literal["increasing", "decreasing"] = "increasing", + # Formerly part of the constructor + prompt_model=None, + initial_prompts=2, + verbose=False, +) -> dspy.Program: + """Create a new program that is optimized for the given task. + + `student` is a program that needs to be optimized, + note that it may be zero-shot or already pre-optimized for demos != []. + + Parameters + ---------- + student : dspy.Program + The program to optimize. + evaluator : dspy.Evaluator + The evaluator to use to score the program. + n_iterations : int, optional + The number of iterations to run, by default 10 + strategy : Literal["best", "last"], optional + The strategy to use to select the final program, by default "best" + sorted_order : Literal["increasing", "decreasing"], optional + The order in which to sort the scores, by default "increasing" + prompt_model : dspy.LanguageModel, optional + The language model to use to generate prompts, by default None + initial_prompts : int, optional + The number of initial prompts to generate, by default 2. + Note that we also use the "plain" signature as a prompt, so the total number of prompts is initial_prompts + 1. + verbose : bool, optional + Whether to print debug information, by default False + + Notes: + ----- + We don't support temperatures, since it tends to break the typed generation. + """ + if n_iterations < 1 + initial_prompts: + raise ValueError("n_iterations must be at least 1 + initial_prompts") + + prompt_model = prompt_model or dspy.settings.lm + MyGenerateInstructionInitial = make_initial_signature(initial_prompts) # noqa: N806 + + module = student.deepcopy() + # In contrast to the original implementation, we don't want the Predict's, but the TypedPredictor's. + # This is because TypedPredictor changes the signature before it runs forward. So changing the signature + # on the Predicts won't work. + named_predictors = [ + (name, module) + for name, module in module.named_sub_modules() + if isinstance(module, TypedPredictor) and not getattr(module, "_compiled", False) + ] + if not named_predictors: + raise ValueError("No unfrozen/uncompiled TypedPredictors found in the module.") + if verbose: + print(f"Found {len(named_predictors)} typed predictors to optimize.") + + candidates = {} + scores = [] + + # First round, just use initial prompts + for name, p in named_predictors: + candidates[name] = [make_info(p.signature)] + + # Make some initial candidates + with dspy.settings.context(lm=prompt_model): + # TODO: Parallelize this + for name, _p in named_predictors: + if verbose: + print(f"Generating {initial_prompts} initial signatures for {name}...") + info = candidates[name][0] # Use initial info, to make sure types are identical + generator = TypedChainOfThought(MyGenerateInstructionInitial[type(info)]) + candidates[name] += generator( + basic_signature=info, + ).proposed_signatures + assert len(candidates[name]) == initial_prompts + 1 # Basic signature + initial prompts + + # Main loop of scoring + generating new candidates + for i in range(n_iterations): + if verbose: + print("\n" + "=" * 80) + print(f"Running eval iteration {i}...") + + # Install signatures + for name, p in named_predictors: + p.signature = candidates[name][i].to_signature() + + # Run evaluator given by user + score = evaluator(module) + scores.append(score) + + # If we are still testing initial prompts, continue + if i + 1 < len(next(iter(candidates.values()))): + continue + + # If we are done, there's no need to generate new candidates + if i + 1 == n_iterations: + break + + # Otherwise generate the next candidate + with dspy.settings.context(lm=prompt_model): + # TODO: Parallelize this + for name, _p in named_predictors: + SignatureInfo = type(candidates[name][0]) # noqa: N806 + generator = TypedPredictor(GenerateSignature[SignatureInfo]) + + demos = [ + dspy.Example( + proposed_signature=info, + score=sc, + ) + for info, sc in zip(candidates[name], scores) + ] + demos.sort(key=(lambda x: x.score), reverse=(sorted_order == "decreasing")) + generator.predictor.demos = demos + + if verbose: + print(f"Generating new signature for {name}...") + new_signature = generator().proposed_signature + candidates[name].append(new_signature) + + if strategy == "last": + return module + + if strategy == "best": + i = scores.index(max(scores)) + for name, p in named_predictors: + p.signature = candidates[name][i].to_signature() + return module + + raise ValueError(f"Invalid strategy: {strategy}") diff --git a/examples/functional/functional.ipynb b/examples/functional/functional.ipynb index cf7088a5d9..90b7f4501f 100644 --- a/examples/functional/functional.ipynb +++ b/examples/functional/functional.ipynb @@ -2,13 +2,15 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n", "Requirement already satisfied: datasets in /opt/homebrew/lib/python3.11/site-packages (2.14.7)\n", "Requirement already satisfied: numpy>=1.17 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (1.26.2)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (12.0.0)\n", @@ -39,9 +41,6 @@ "Requirement already satisfied: pytz>=2020.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: tzdata>=2022.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: six>=1.5 in /opt/homebrew/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.11 -m pip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } @@ -61,17 +60,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 26, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/homebrew/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "data": { "text/plain": [ @@ -82,7 +73,7 @@ " 'entry_point': 'has_close_elements'}" ] }, - "execution_count": 2, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -102,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -134,15 +125,21 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "Now parsing: '{\"code\": \"from typing import List\\\\n\\\\n\\\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\\\n \\\\\"\\\\\"\\\\\" Check if in given list of numbers, are any two numbers closer to each other than\\\\n given threshold.\\\\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\\\n False\\\\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\\\n True\\\\n \\\\\"\\\\\"\\\\\"\\\\n for i in range(len(numbers)):\\\\n for j in range(i+1, len(numbers)):\\\\n if abs(numbers[i] - numbers[j]) < threshold:\\\\n return True\\\\n return False\\\\n\"}'\n", + "Parsed: PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n given threshold.\\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n False\\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n True\\n \"\"\"\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\\n')\n", + "is this the problem?\n", + "kwargs={'prompt': PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n given threshold.\\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n False\\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n True\\n \"\"\"\\n'), 'test': PythonCode(code=\"\\n\\nMETADATA = {\\n 'author': 'jt',\\n 'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\"), 'entry_point': 'has_close_elements'}\n", + "parsed={'solution': PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n given threshold.\\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n False\\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n True\\n \"\"\"\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\\n')}\n", + "after wards\n", "Prediction(\n", - " solution=PythonCode(code='def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False')\n", + " solution=PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n given threshold.\\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n False\\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n True\\n \"\"\"\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\\n')\n", ")\n" ] } @@ -170,7 +167,7 @@ "# The signature is the main DSpy object. Note that we have types for the input and output fields,\n", "# which was not possible beofore.\n", "class CodeSignature(Signature):\n", - " prompt: str = InputField()\n", + " prompt: PythonCode = InputField()\n", " test: PythonCode = InputField()\n", " entry_point: str = InputField()\n", " solution: PythonCode = OutputField()\n", @@ -180,9 +177,7 @@ " prompt=PythonCode(code=ds['test'][0]['prompt']),\n", " test=PythonCode(code=ds['test'][0]['test']),\n", " entry_point=ds['test'][0]['entry_point']\n", - ")\n", - "\n", - "print(prediction)" + ")\n" ] }, { @@ -194,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -205,19 +200,40 @@ "\n", "\n", "\n", - "Make a very succinct json object that validates with the following schema\n", + "Given the fields `prompt`, `test`, `entry_point`, produce the fields `solution`.\n", "\n", "---\n", "\n", "Follow the following format.\n", "\n", - "Json Schema: ${json_schema}\n", - "Json Object: ${json_object}\n", + "Prompt: ${prompt}\n", + "\n", + "Test: ${test}\n", + "\n", + "Entry Point: ${entry_point}\n", + "\n", + "Past Error (solution): An error to avoid in the future\n", + "\n", + "Past Error (solution, 2): An error to avoid in the future\n", + "\n", + "Solution:\n", + "${solution}. Respond with a single JSON object. \n", + "You MUST use this format: {\"code\": \"print('Hello, World!')\"}\n", + "JSON Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n", "\n", "---\n", "\n", - "Json Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n", - "Json Object:\u001b[32m {\"code\": \"print('Hello, World!')\"}\u001b[0m\n", + "Prompt: code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n given threshold.\\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n False\\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n True\\n \"\"\"\\n'\n", + "\n", + "Test: {\"code\":\"\\n\\nMETADATA = {\\n 'author': 'jt',\\n 'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\"}\n", + "\n", + "Entry Point: has_close_elements\n", + "\n", + "Past Error (solution): Input should be a valid string: prompt (error type: string_type)\n", + "\n", + "Past Error (solution, 2): Value error, Code is not syntactically valid: unexpected character after line continuation character (, line 1): code (error type: value_error)\n", + "\n", + "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"}\u001b[0m\n", "\n", "\n", "\n", @@ -237,7 +253,16 @@ "\n", "Entry Point: ${entry_point}\n", "\n", - "Solution: ${solution}. Respond with a single JSON object using the schema {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}. For example: {\"code\": \"print('Hello, World!')\"}\n", + "Past Error (solution): An error to avoid in the future\n", + "\n", + "Past Error (solution, 2): An error to avoid in the future\n", + "\n", + "Past Error (solution, 3): An error to avoid in the future\n", + "\n", + "Solution:\n", + "${solution}. Respond with a single JSON object. \n", + "You MUST use this format: {\"code\": \"print('Hello, World!')\"}\n", + "JSON Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n", "\n", "---\n", "\n", @@ -247,9 +272,13 @@ "\n", "Entry Point: has_close_elements\n", "\n", - "Solution:\u001b[32m {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n", + "Past Error (solution): Input should be a valid string: prompt (error type: string_type)\n", + "\n", + "Past Error (solution, 2): Value error, Code is not syntactically valid: unexpected character after line continuation character (, line 1): code (error type: value_error)\n", + "\n", + "Past Error (solution, 3): Input should be a valid string: prompt (error type: string_type)\n", "\n", - "{\"code\": \"from typing import List\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"}\u001b[0m\n", + "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\\\n for i in range(len(numbers)):\\\\n for j in range(i+1, len(numbers)):\\\\n if abs(numbers[i] - numbers[j]) < threshold:\\\\n return True\\\\n return False\"}\u001b[0m\n", "\n", "\n", "\n", @@ -271,7 +300,16 @@ "\n", "Past Error (solution): An error to avoid in the future\n", "\n", - "Solution: ${solution}. Respond with a single JSON object using the schema {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}. For example: {\"code\": \"print('Hello, World!')\"}\n", + "Past Error (solution, 2): An error to avoid in the future\n", + "\n", + "Past Error (solution, 3): An error to avoid in the future\n", + "\n", + "Past Error (solution, 4): An error to avoid in the future\n", + "\n", + "Solution:\n", + "${solution}. Respond with a single JSON object. \n", + "You MUST use this format: {\"code\": \"print('Hello, World!')\"}\n", + "JSON Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n", "\n", "---\n", "\n", @@ -281,9 +319,15 @@ "\n", "Entry Point: has_close_elements\n", "\n", - "Past Error (solution): 1 validation error for PythonCode Invalid JSON: trailing characters at line 3 column 1 [type=json_invalid, input_value='{\"properties\": {\"code\": ...ue\\\\n return False\"}', input_type=str] For further information visit https://errors.pydantic.dev/2.5/v/json_invalid\n", + "Past Error (solution): Input should be a valid string: prompt (error type: string_type)\n", "\n", - "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"}\u001b[0m\n", + "Past Error (solution, 2): Value error, Code is not syntactically valid: unexpected character after line continuation character (, line 1): code (error type: value_error)\n", + "\n", + "Past Error (solution, 3): Input should be a valid string: prompt (error type: string_type)\n", + "\n", + "Past Error (solution, 4): Value error, Code is not syntactically valid: unexpected character after line continuation character (, line 1): code (error type: value_error)\n", + "\n", + "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\\\n for i in range(len(numbers)):\\\\n for j in range(i+1, len(numbers)):\\\\n if abs(numbers[i] - numbers[j]) < threshold:\\\\n return True\\\\n return False\"}\u001b[0m\n", "\n", "\n", "\n" @@ -291,7 +335,96 @@ } ], "source": [ - "lm.inspect_history(n=3)" + "lm.inspect_history(n=3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def has_close_elements(numbers: List[float], threshold: float) -> bool:\n", + " for i in range(len(numbers)):\n", + " for j in range(i+1, len(numbers)):\n", + " if abs(numbers[i] - numbers[j]) < threshold:\n", + " return True\n", + " return False\n" + ] + } + ], + "source": [ + "d = {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"}\n", + "print(d[\"code\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "JSONDecodeError", + "evalue": "Invalid control character at: line 1 column 82 (char 81)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjson\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m{\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcode\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m: \u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdef has_close_elements(numbers: List[float], threshold: float) -> bool:\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m for i in range(len(numbers)):\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m for j in range(i+1, len(numbers)):\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m if abs(numbers[i] - numbers[j]) < threshold:\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m return True\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m return False\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m}\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 341\u001b[0m s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mdecode(detect_encoding(s), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msurrogatepass\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m JSONDecoder\n", + "File \u001b[0;32m/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, s, _w\u001b[38;5;241m=\u001b[39mWHITESPACE\u001b[38;5;241m.\u001b[39mmatch):\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03m containing a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m end \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(s):\n", + "File \u001b[0;32m/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 344\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Decode a JSON document from ``s`` (a ``str`` beginning with\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;124;03ma JSON document) and return a 2-tuple of the Python\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;124;03mrepresentation and the index in ``s`` where the document ended.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 350\u001b[0m \n\u001b[1;32m 351\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[0;31mJSONDecodeError\u001b[0m: Invalid control character at: line 1 column 82 (char 81)" + ] + } + ], + "source": [ + "import json\n", + "json.loads('{\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'code': 'def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False'}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ujson\n", + "ujson.loads('{\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"} ')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'code': 'def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False'}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "json.loads(ujson.dumps(ujson.loads('{\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n for i in range(len(numbers)):\\n for j in range(i+1, len(numbers)):\\n if abs(numbers[i] - numbers[j]) < threshold:\\n return True\\n return False\"} ')))" ] }, { @@ -313,7 +446,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -341,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -410,7 +543,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -677,7 +810,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -988,7 +1121,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1207,7 +1340,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/examples/functional/signature_opt_typed.ipynb b/examples/functional/signature_opt_typed.ipynb new file mode 100644 index 0000000000..7447a965ee --- /dev/null +++ b/examples/functional/signature_opt_typed.ipynb @@ -0,0 +1,325 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "load_dotenv()\n", + "assert 'OPENAI_API_KEY' in os.environ" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import dspy\n", + "turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=4000)\n", + "gpt4 = dspy.OpenAI(model='gpt-4', max_tokens=4000)\n", + "dspy.settings.configure(lm=turbo)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Prediction(\n", + " answer='Paris'\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dspy.TypedPredictor(\"question -> answer\")(question=\"What is the capital of France?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(20, 50)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dspy.datasets import HotPotQA\n", + "\n", + "# Load the dataset.\n", + "dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)\n", + "\n", + "# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.\n", + "trainset = [x.with_inputs('question') for x in dataset.train]\n", + "devset = [x.with_inputs('question') for x in dataset.dev]\n", + "\n", + "len(trainset), len(devset)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class BasicQA(dspy.Signature):\n", + " \"\"\"Answer questions with short factoid answers.\"\"\"\n", + "\n", + " question = dspy.InputField()\n", + " answer = dspy.OutputField(desc=\"often between 1 and 5 words\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 typed predictors to optimize.\n", + "Generating 4 initial signatures for base...\n", + "\n", + "================================================================================\n", + "Running eval iteration 0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 16 / 50 (32.0): 100%|██████████| 50/50 [00:00<00:00, 4290.32it/s]\n", + "/Users/ahle/repos/dspy/dspy/evaluate/evaluate.py:142: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", + " df = df.applymap(truncate_cell)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 16 / 50 (32.0%)\n", + "\n", + "================================================================================\n", + "Running eval iteration 1...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 16 / 50 (32.0): 100%|██████████| 50/50 [00:02<00:00, 22.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 16 / 50 (32.0%)\n", + "\n", + "================================================================================\n", + "Running eval iteration 2...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 19 / 50 (38.0): 100%|██████████| 50/50 [00:04<00:00, 10.28it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 19 / 50 (38.0%)\n", + "\n", + "================================================================================\n", + "Running eval iteration 3...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 11 / 50 (22.0): 100%|██████████| 50/50 [00:05<00:00, 8.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 11 / 50 (22.0%)\n", + "\n", + "================================================================================\n", + "Running eval iteration 4...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 15 / 50 (30.0): 100%|██████████| 50/50 [00:02<00:00, 24.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 15 / 50 (30.0%)\n", + "Generating new signature for base...\n", + "\n", + "================================================================================\n", + "Running eval iteration 5...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 18 / 50 (36.0): 100%|██████████| 50/50 [00:02<00:00, 21.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 18 / 50 (36.0%)\n", + "Generating new signature for base...\n", + "\n", + "================================================================================\n", + "Running eval iteration 6...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 6 / 50 (12.0): 100%|██████████| 50/50 [00:03<00:00, 13.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 6 / 50 (12.0%)\n", + "Generating new signature for base...\n", + "\n", + "================================================================================\n", + "Running eval iteration 7...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 17 / 50 (34.0): 100%|██████████| 50/50 [00:02<00:00, 19.56it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Metric: 17 / 50 (34.0%)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from dspy.evaluate import Evaluate\n", + "from dspy.evaluate.metrics import answer_exact_match\n", + "from dspy.functional import TypedPredictor\n", + "from dspy.teleprompt.signature_opt_typed import optimize_signature\n", + "\n", + "evaluator = Evaluate(devset=devset, metric=answer_exact_match, num_threads=10, display_progress=True)\n", + "\n", + "program = optimize_signature(\n", + " student=TypedPredictor(BasicQA),\n", + " evaluator=evaluator,\n", + " initial_prompts=4,\n", + " n_iterations=8,\n", + " verbose=True,\n", + " prompt_model=gpt4,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StringSignature(question -> answer\n", + " instructions='You are highly intelligent. Please provide short, factual answers to the following questions.'\n", + " question = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Inquiry:', 'desc': '${question}'})\n", + " answer = Field(annotation=str required=True json_schema_extra={'desc': 'usually between 1 and 5 words', '__dspy_field_type': 'output', 'prefix': 'Reply:'})\n", + ")\n" + ] + } + ], + "source": [ + "print(program.signature)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py39", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/generation.py b/examples/generation.py new file mode 100644 index 0000000000..4ccb2d8fee --- /dev/null +++ b/examples/generation.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel, Field + +import dspy +from dspy.functional import TypedPredictor +from dspy.teleprompt import LabeledFewShot + +turbo = dspy.OpenAI(model='gpt-3.5-turbo') +colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') +dspy.settings.configure(lm=turbo, rm=colbertv2_wiki17_abstracts) + +class SyntheticFact(BaseModel): + fact: str = Field(..., description="a statement") + varacity: bool = Field(..., description="is the statement true or false") + +class ExampleSignature(dspy.Signature): + """Generate an example of a synthetic fact.""" + fact: SyntheticFact = dspy.OutputField() + +generator = TypedPredictor(ExampleSignature) +examples = generator(config=dict(n=10)) + +# If you have examples and want more +existing_examples = [ + dspy.Example(fact="The sky is blue", varacity=True), + dspy.Example(fact="The sky is green", varacity=False), +] +trained = LabeledFewShot().compile(student=generator, trainset=existing_examples) + +augmented_examples = trained(config=dict(n=10)) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index cabbc7cd09..2289ce697c 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -2,13 +2,13 @@ import textwrap import pydantic from pydantic import Field, BaseModel, field_validator -from typing import Annotated +from typing import Annotated, Generic, Literal, TypeVar from typing import List import pytest import dspy -from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor, functional +from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor, TypedChainOfThought from dspy.primitives.example import Example from dspy.teleprompt.bootstrap import BootstrapFewShot from dspy.teleprompt.vanilla import LabeledFewShot @@ -122,7 +122,7 @@ def forward(self, **kwargs): qa = QA() assert isinstance(qa, FunctionalModule) - assert isinstance(qa.answer, functional._StripOutput) + assert isinstance(qa.answer, dspy.Module) question, answer = qa(topic="Physics") @@ -407,6 +407,7 @@ def get_user_details() -> UserDetails: with pytest.raises(ValueError): get_user_details() + print(lm.get_convo(-1)) assert lm.get_convo(-1) == textwrap.dedent( """\ Given the fields , produce the fields `get_user_details`. @@ -467,6 +468,23 @@ class TestSignature(dspy.Signature): assert output == [0, 1, 2] +def test_multiple_outputs_int_cot(): + # Note: Multiple outputs only work when the language model "speculatively" generates all the outputs in one go. + lm = DummyLM( + [ + "thoughts 0\nOutput: 0\n", + "thoughts 1\nOutput: 1\n", + "thoughts 2\nOutput: 2\n", + ] + ) + dspy.settings.configure(lm=lm) + + test = TypedChainOfThought("input:str -> output:int") + + output = test(input="8", config=dict(n=3)).completions.output + assert output == [0, 1, 2] + + def test_parse_type_string(): lm = DummyLM([str(i) for i in range(100)]) dspy.settings.configure(lm=lm) @@ -477,6 +495,28 @@ def test_parse_type_string(): assert output == [0, 1, 2] +def test_literal(): + lm = DummyLM([f'{{"value": "{i}"}}' for i in range(100)]) + dspy.settings.configure(lm=lm) + + @predictor + def f() -> Literal["2", "3"]: + pass + + assert f() == "2" + + +def test_literal_int(): + lm = DummyLM([f'{{"value": {i}}}' for i in range(100)]) + dspy.settings.configure(lm=lm) + + @predictor + def f() -> Literal[2, 3]: + pass + + assert f() == 2 + + def test_fields_on_base_signature(): class SimpleOutput(dspy.Signature): output: float = dspy.OutputField(gt=0, lt=1) @@ -532,3 +572,66 @@ class ExampleSignature(dspy.Signature): augmented_examples = trained(config=dict(n=3)) for ex in augmented_examples.completions.fact: assert isinstance(ex, SyntheticFact) + + +def test_list_input2(): + # Inspired by the Signature Optimizer + + class ScoredString(pydantic.BaseModel): + string: str + score: float + + class ScoredSignature(dspy.Signature): + attempted_signatures: list[ScoredString] = dspy.InputField() + proposed_signature: str = dspy.OutputField() + + program = TypedChainOfThought(ScoredSignature) + + lm = DummyLM(["Thoughts", "Output"]) + dspy.settings.configure(lm=lm) + + output = program( + attempted_signatures=[ + ScoredString(string="string 1", score=0.5), + ScoredString(string="string 2", score=0.4), + ScoredString(string="string 3", score=0.3), + ] + ).proposed_signature + + print(lm.get_convo(-1)) + + assert output == "Output" + + assert lm.get_convo(-1) == textwrap.dedent("""\ + Given the fields `attempted_signatures`, produce the fields `proposed_signature`. + + --- + + Follow the following format. + + Attempted Signatures: ${attempted_signatures} + Reasoning: Let's think step by step in order to ${produce the proposed_signature}. We ... + Proposed Signature: ${proposed_signature} + + --- + + Attempted Signatures: [{"string":"string 1","score":0.5},{"string":"string 2","score":0.4},{"string":"string 3","score":0.3}] + Reasoning: Let's think step by step in order to Thoughts + Proposed Signature: Output""") + + +def test_generic_signature(): + T = TypeVar("T") + + class GenericSignature(dspy.Signature, Generic[T]): + """My signature""" + + output: T = dspy.OutputField() + + predictor = TypedPredictor(GenericSignature[int]) + assert predictor.signature.instructions == "My signature" + + lm = DummyLM(["23"]) + dspy.settings.configure(lm=lm) + + assert predictor().output == 23 diff --git a/tests/functional/test_signature_opt_typed.py b/tests/functional/test_signature_opt_typed.py new file mode 100644 index 0000000000..44adb9d0ea --- /dev/null +++ b/tests/functional/test_signature_opt_typed.py @@ -0,0 +1,175 @@ +import json +import dspy +from dspy.evaluate import Evaluate +from dspy.functional import TypedPredictor +from dspy.teleprompt.signature_opt_typed import ( + GenerateSignature, + make_info, + optimize_signature, +) +from dspy.utils import DummyLM + +from dspy.evaluate import Evaluate +from dspy.evaluate.metrics import answer_exact_match +from dspy.functional import TypedPredictor + + +class BasicQA(dspy.Signature): + question: str = dspy.InputField() + answer: str = dspy.OutputField() + + +hotpotqa = [ + ex.with_inputs("question") + for ex in [ + dspy.Example( + question="At My Window was released by which American singer-songwriter?", + answer="John Townes Van Zandt", + ), + dspy.Example( + question="which American actor was Candace Kita guest starred with ", + answer="Bill Murray", + ), + dspy.Example( + question="Which of these publications was most recently published, Who Put the Bomp or Self?", + answer="Self", + ), + dspy.Example( + question="The Victorians - Their Story In Pictures is a documentary series written by an author born in what year?", + answer="1950", + ), + dspy.Example( + question="Which magazine has published articles by Scott Shaw, Tae Kwon Do Times or Southwest Art?", + answer="Tae Kwon Do Times", + ), + dspy.Example( + question="In what year was the club founded that played Manchester City in the 1972 FA Charity Shield", + answer="1874", + ), + dspy.Example( + question="Which is taller, the Empire State Building or the Bank of America Tower?", + answer="The Empire State Building", + ), + dspy.Example( + question='Which American actress who made their film debut in the 1995 teen drama "Kids" was the co-founder of Voto Latino?', + answer="Rosario Dawson", + ), + dspy.Example( + question="Tombstone stared an actor born May 17, 1955 known as who?", + answer="Bill Paxton", + ), + dspy.Example( + question="What is the code name for the German offensive that started this Second World War engagement on the Eastern Front (a few hundred kilometers from Moscow) between Soviet and German forces, which included 102nd Infantry Division?", + answer="Operation Citadel", + ), + dspy.Example( + question='Who acted in the shot film The Shore and is also the youngest actress ever to play Ophelia in a Royal Shakespeare Company production of "Hamlet." ?', + answer="Kerry Condon", + ), + dspy.Example( + question="Which company distributed this 1977 American animated film produced by Walt Disney Productions for which Sherman Brothers wrote songs?", + answer="Buena Vista Distribution", + ), + dspy.Example( + question="Samantha Cristoforetti and Mark Shuttleworth are both best known for being first in their field to go where? ", + answer="space", + ), + dspy.Example( + question="Having the combination of excellent foot speed and bat speed helped Eric Davis, create what kind of outfield for the Los Angeles Dodgers? ", + answer="Outfield of Dreams", + ), + dspy.Example( + question="Which Pakistani cricket umpire who won 3 consecutive ICC umpire of the year awards in 2009, 2010, and 2011 will be in the ICC World Twenty20?", + answer="Aleem Sarwar Dar", + ), + dspy.Example( + question="The Organisation that allows a community to influence their operation or use and to enjoy the benefits arisingwas founded in what year?", + answer="2010", + ), + dspy.Example( + question='"Everything Has Changed" is a song from an album released under which record label ?', + answer="Big Machine Records", + ), + dspy.Example( + question="Who is older, Aleksandr Danilovich Aleksandrov or Anatoly Fomenko?", + answer="Aleksandr Danilovich Aleksandrov", + ), + dspy.Example( + question="On the coast of what ocean is the birthplace of Diogal Sakho?", + answer="Atlantic", + ), + dspy.Example( + question="This American guitarist best known for her work with the Iron Maidens is an ancestor of a composer who was known as what?", + answer="The Waltz King", + ), + ] +] + + +def old_test_signature_info(): + info = make_info(BasicQA) + SignatureInfo = type(info) + + devset = [ + dspy.Example( + instructions="Answer the following questions", + question_desc="Some question to answer", + question_prefix="Q: ", + answer_desc="A short answer to the question", + answer_prefix="A: ", + ), + ] + + lm = DummyLM( + [ + json.dumps(dict(devset[0])), # Proposed signature + ] + ) + dspy.settings.configure(lm=lm) + + generator = TypedPredictor(GenerateInstructionGivenAttempts[SignatureInfo]) + + res = generator(attempted_signatures=[ScoredSignature[SignatureInfo](signature=info, score=50)]) + assert res.proposed_signature == SignatureInfo(**devset[0]) + + # Test the "to_signature" method + + class OutputSignature(dspy.Signature): + """Answer the following questions""" + + question: str = dspy.InputField(desc="Some question to answer", prefix="Q: ") + answer: str = dspy.OutputField(desc="A short answer to the question", prefix="A: ") + + assert res.proposed_signature.to_signature().equals(OutputSignature) + + +def test_opt(): + qa_model = DummyLM([]) + prompt_model = DummyLM( + [ + # Seed prompts + "some thoughts", + '{"value": [{"instructions": "I", "question_desc": "$q", "question_prefix": "Q:", "answer_desc": "$a", "answer_prefix": "A:"}]}', + ] + ) + dspy.settings.configure(lm=qa_model) + + program = optimize_signature( + student=TypedPredictor(BasicQA), + evaluator=Evaluate(devset=hotpotqa, metric=answer_exact_match, num_threads=1), + initial_prompts=1, + n_iterations=2, + verbose=True, + prompt_model=prompt_model, + strategy="last", + ) + + # Since we are requesting the last signature, it doesn't matter that our qa_model is + # bad, and gets 0 score. We should still get the last signature. + class ExpectedSignature(dspy.Signature): + "I" + + question: str = dspy.InputField(desc="$q", prefix="Q:") + answer: str = dspy.OutputField(desc="$a", prefix="A:") + + assert program.signature.equals(ExpectedSignature) diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py index e44b3a135c..c0407b938c 100644 --- a/tests/predict/test_predict.py +++ b/tests/predict/test_predict.py @@ -1,14 +1,14 @@ import dspy from dspy import Predict, Signature from dspy.utils.dummies import DummyLM +import copy +import textwrap def test_initialization_with_string_signature(): signature_string = "input1, input2 -> output" predict = Predict(signature_string) - expected_instruction = ( - "Given the fields `input1`, `input2`, produce the fields `output`." - ) + expected_instruction = "Given the fields `input1`, `input2`, produce the fields `output`." assert predict.signature.instructions == expected_instruction assert predict.signature.instructions == Signature(signature_string).instructions @@ -89,3 +89,58 @@ def test_multi_output(): results = program(question="What is 1+1?") assert results.completions.answer[0] == "my first answer" assert results.completions.answer[1] == "my second answer" + + +def test_multi_output2(): + program = Predict("question -> answer1, answer2", n=2) + dspy.settings.configure( + lm=DummyLM( + [ + "my 0 answer\nAnswer 2: my 2 answer", + "my 1 answer\nAnswer 2: my 3 answer", + ], + ) + ) + results = program(question="What is 1+1?") + assert results.completions.answer1[0] == "my 0 answer" + assert results.completions.answer1[1] == "my 1 answer" + assert results.completions.answer2[0] == "my 2 answer" + assert results.completions.answer2[1] == "my 3 answer" + + +def test_named_predictors(): + class MyModule(dspy.Module): + def __init__(self): + super().__init__() + self.inner = Predict("question -> answer") + + program = MyModule() + assert program.named_predictors() == [("inner", program.inner)] + + # Check that it also works the second time. + program2 = copy.deepcopy(program) + assert program2.named_predictors() == [("inner", program2.inner)] + + +def test_output_only(): + class OutputOnlySignature(dspy.Signature): + output = dspy.OutputField() + + predictor = Predict(OutputOnlySignature) + + lm = DummyLM(["short answer"]) + dspy.settings.configure(lm=lm) + assert predictor().output == "short answer" + + assert lm.get_convo(-1) == textwrap.dedent("""\ + Given the fields , produce the fields `output`. + + --- + + Follow the following format. + + Output: ${output} + + --- + + Output: short answer""") diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py index b1d7c89725..87ce09395f 100644 --- a/tests/primitives/test_program.py +++ b/tests/primitives/test_program.py @@ -1,4 +1,5 @@ import dspy +from dspy.primitives.module import BaseModule from dspy.primitives.program import ( Module, set_attribute_by_name, @@ -19,9 +20,7 @@ def forward(self, question): def test_module_initialization(): module = Module() - assert ( - module._compiled is False - ), "Module _compiled attribute should be False upon initialization" + assert module._compiled is False, "Module _compiled attribute should be False upon initialization" def test_named_predictors(): @@ -29,25 +28,19 @@ def test_named_predictors(): named_preds = module.named_predictors() assert len(named_preds) == 2, "Should identify correct number of Predict instances" names, preds = zip(*named_preds) - assert ( - "predict1" in names and "predict2" in names - ), "Named predictors should include 'predict1' and 'predict2'" + assert "predict1" in names and "predict2" in names, "Named predictors should include 'predict1' and 'predict2'" def test_predictors(): module = HopModule() preds = module.predictors() assert len(preds) == 2, "Should return correct number of Predict instances" - assert all( - isinstance(p, dspy.Predict) for p in preds - ), "All returned items should be instances of PredictMock" + assert all(isinstance(p, dspy.Predict) for p in preds), "All returned items should be instances of PredictMock" def test_forward(): program = HopModule() - dspy.settings.configure( - lm=DummyLM({"What is 1+1?": "let me check", "let me check": "2"}) - ) + dspy.settings.configure(lm=DummyLM({"What is 1+1?": "let me check", "let me check": "2"})) result = program(question="What is 1+1?").answer assert result == "2" @@ -64,3 +57,47 @@ def __init__(self): names, _preds = zip(*named_preds) assert "hop.predict1" in names assert "hop.predict2" in names + + +class SubModule(BaseModule): + pass + + +class AnotherSubModule(BaseModule): + pass + + +def test_empty_module(): + module = BaseModule() + assert list(module.named_sub_modules()) == [("base", module)] + + +def test_single_level(): + module = BaseModule() + module.sub = SubModule() + expected = [("base", module), ("base.sub", module.sub)] + assert list(module.named_sub_modules()) == expected + + +def test_multiple_levels(): + module = BaseModule() + module.sub = SubModule() + module.sub.subsub = SubModule() + expected = [("base", module), ("base.sub", module.sub), ("base.sub.subsub", module.sub.subsub)] + assert list(module.named_sub_modules()) == expected + + +def test_multiple_sub_modules(): + module = BaseModule() + module.sub1 = SubModule() + module.sub2 = SubModule() + expected = [("base", module), ("base.sub1", module.sub1), ("base.sub2", module.sub2)] + assert sorted(list(module.named_sub_modules())) == sorted(expected) + + +def test_non_base_module_attributes(): + module = BaseModule() + module.sub = SubModule() + module.not_a_sub = "Not a BaseModule" + expected = [("base", module), ("base.sub", module.sub)] + assert list(module.named_sub_modules()) == expected diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index 554d9b1274..d0eb899d13 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -1,8 +1,12 @@ +import textwrap import pytest import pydantic from dspy import Signature, infer_prefix, InputField, OutputField from typing import List +import dspy +from dspy.utils.dummies import DummyLM + def test_field_types_and_custom_attributes(): class TestSignature(Signature): @@ -174,3 +178,31 @@ class SubSignature(Signature): assert SubSignature.__name__ == "SubSignature" value = SubSignature(input="test", output="test") assert isinstance(value, SubSignature) + + +def test_multiline_instructions(): + class MySignature(Signature): + """First line + Second line""" + + output = OutputField() + + predictor = dspy.Predict(MySignature) + + lm = DummyLM(["short answer"]) + dspy.settings.configure(lm=lm) + assert predictor().output == "short answer" + + assert lm.get_convo(-1) == textwrap.dedent("""\ + First line + Second line + + --- + + Follow the following format. + + Output: ${output} + + --- + + Output: short answer""")