diff --git a/.gitignore b/.gitignore index c47837feb8..7d9919d9aa 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,9 @@ __pycache__/ *.py[cod] *$py.class +# Vim +*.swp + # Jupyter Notebook .ipynb_checkpoints # notebooks/ @@ -42,4 +45,4 @@ finetuning_ckpts/ .idea assertion.log *.log -*.db \ No newline at end of file +*.db diff --git a/dsp/templates/template_v2.py b/dsp/templates/template_v2.py index 0e1b368854..f61085dd05 100644 --- a/dsp/templates/template_v2.py +++ b/dsp/templates/template_v2.py @@ -91,8 +91,8 @@ def query(self, example: Example, is_demo: bool = False) -> str: if field.input_variable in self.format_handlers: format_handler = self.format_handlers[field.input_variable] else: - def format_handler(x): + assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}" return " ".join(x.split()) formatted_value = format_handler(example[field.input_variable]) diff --git a/dspy/functional/__init__.py b/dspy/functional/__init__.py new file mode 100644 index 0000000000..11fb1bc4d4 --- /dev/null +++ b/dspy/functional/__init__.py @@ -0,0 +1 @@ +from .functional import cot, predictor, FunctionalModule, TypedPredictor diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py new file mode 100644 index 0000000000..bc10def220 --- /dev/null +++ b/dspy/functional/functional.py @@ -0,0 +1,327 @@ +import inspect, os, openai, dspy, typing, pydantic +from typing import Annotated +import typing +from dsp.templates import passages2text +import json + + +MAX_RETRIES = 3 + + +def predictor(func): + signature = _func_to_signature(func) + return TypedPredictor(signature, chain_of_thought=False, simple_output=True) + + +def cot(func): + signature = _func_to_signature(func) + return TypedPredictor(signature, chain_of_thought=True, simple_output=True) + + +class FunctionalModule(dspy.Module): + def __init__(self): + super().__init__() + for name in dir(self): + attr = getattr(self, name) + if isinstance(attr, dspy.Module): + self.__dict__[name] = attr.copy() + + +class TypedPredictor(dspy.Module): + def __init__(self, signature, chain_of_thought=False, simple_output=False): + super().__init__() + self.signature = signature + self.predictor = dspy.Predict(signature) + self.chain_of_thought = chain_of_thought + self.simple_output = simple_output + + def copy(self): + return TypedPredictor(self.signature, self.chain_of_thought, self.simple_output) + + def _prepare_signature(self): + """Add formats and parsers to the signature fields, based on the type + annotations of the fields.""" + signature = self.signature + for name, field in self.signature.fields.items(): + is_output = field.json_schema_extra["__dspy_field_type"] == "output" + type_ = field.annotation + if is_output: + if type_ in (str, int, float, bool): + signature = signature.with_updated_fields( + name, + desc=field.json_schema_extra.get("desc", "") + + (f". Respond with a single {type_.__name__} value"), + format=lambda x: x if isinstance(x, str) else str(x), + parser=type_, + ) + else: + # Anything else we wrap in a pydantic object + unwrap = lambda x: x + if not inspect.isclass(type_) or not issubclass( + type_, pydantic.BaseModel + ): + type_ = pydantic.create_model( + "Output", value=(type_, ...), __base__=pydantic.BaseModel + ) + unwrap = lambda x: x.value + signature = signature.with_updated_fields( + name, + desc=field.json_schema_extra.get("desc", "") + + ( + f". 
Respond with a single JSON object using the schema " + + json.dumps(type_.model_json_schema()) + ), + format=lambda x: x if isinstance(x, str) else x.model_dump_json(), + parser=lambda x: unwrap( + type_.model_validate_json(_unwrap_json(x)) + ), + ) + else: # If input field + format = lambda x: x if isinstance(x, str) else str(x) + if type_ in (list[str], tuple[str]): + format = passages2text + elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): + format = lambda x: x if isinstance(x, str) else x.model_dump_json() + signature = signature.with_updated_fields(name, format=format) + + if self.chain_of_thought: + output_keys = ", ".join(signature.output_fields.keys()) + signature = signature.prepend( + "reasoning", + dspy.OutputField( + prefix="Reasoning: Let's think step by step in order to", + desc="${produce the " + output_keys + "}. We ...", + ), + ) + return signature + + def forward(self, **kwargs): + modified_kwargs = kwargs.copy() + signature = self._prepare_signature() + for try_i in range(MAX_RETRIES): + result = self.predictor(**modified_kwargs, new_signature=signature) + errors = {} + parsed_results = {} + # Parse the outputs + for name, field in signature.output_fields.items(): + try: + value = getattr(result, name) + parser = field.json_schema_extra.get("parser", lambda x: x) + parsed_results[name] = parser(value) + except (pydantic.ValidationError, ValueError) as e: + errors[name] = e + if errors: + # Add new fields for each error + for name, error in errors.items(): + modified_kwargs[f"error_{name}_{try_i}"] = str(error) + signature = signature.append( + f"error_{name}_{try_i}", + dspy.InputField( + prefix=f"Past Error " + + (f"({name}):" if try_i == 0 else f"({name}, {try_i+1}):"), + desc="An error to avoid in the future", + ), + ) + else: + # If there are no errors, we return the parsed results + for name, value in parsed_results.items(): + setattr(result, name, value) + if self.simple_output: + *_, last_output = signature.output_fields.keys() + return result[last_output] + return result + raise ValueError("Too many retries") + + +def _func_to_signature(func): + """Make a dspy.Signature based on a function definition.""" + sig = inspect.signature(func) + annotations = typing.get_type_hints(func) + output_key = func.__name__ + instructions = func.__doc__ + fields = {} + + # Input fields + for param in sig.parameters.values(): + if param.name == "self": + continue + # We default to str as the type of the input + annotation = annotations.get(param.name, str) + kwargs = {} + if typing.get_origin(annotation) is Annotated: + annotation, kwargs["desc"] = typing.get_args(annotation) + fields[param.name] = (annotation, dspy.InputField(**kwargs)) + + # Output field + kwargs = {} + annotation = annotations.get("return", str) + if typing.get_origin(annotation) is Annotated: + annotation, kwargs["desc"] = typing.get_args(annotation) + fields[output_key] = (annotation, dspy.OutputField(**kwargs)) + + return dspy.Signature(fields, instructions) + + +def _unwrap_json(output): + output = output.strip() + if output.startswith("```"): + if not output.startswith("```json"): + raise ValueError("json output should start with ```json") + if not output.endswith("```"): + raise ValueError("json output should end with ```") + output = output[7:-3].strip() + if not output.startswith("{") or not output.endswith("}"): + raise ValueError("json output should start and end with { and }") + return output + + +################################################################################ 
+# Example usage +################################################################################ + + +def main(): + class Answer(pydantic.BaseModel): + value: float + certainty: float + comments: list[str] = pydantic.Field( + description="At least two comments about the answer" + ) + + class QA(dspy.Module): + @predictor + def hard_question(self, topic: str) -> str: + """Think of a hard factual question about a topic. It should be answerable with a number.""" + + @cot + def answer(self, question: Annotated[str, "Question to answer"]) -> Answer: + pass + + def forward(self, **kwargs): + question = self.hard_question(**kwargs) + return (question, self.answer(question=question)) + + openai.api_key = os.getenv("OPENAI_API_KEY") + lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000) + # lm = dspy.OpenAI(model="gpt-4", max_tokens=4000) + # lm = dspy.OpenAI(model="gpt-4-preview-1106", max_tokens=4000) + with dspy.context(lm=lm): + qa = QA() + question, answer = qa(topic="Physics") + # lm.inspect_history(n=5) + + print("Question:", question) + print("Answer:", answer) + + +################################################################################ +# HotpotQA example with SimpleBaleen +################################################################################ + + +def validate_context_and_answer_and_hops(example, pred, trace=None): + if not dspy.evaluate.answer_exact_match(example, pred): + return False + if not dspy.evaluate.answer_passage_match(example, pred): + return False + + hops = [example.question] + [ + outputs.query for *_, outputs in trace if "query" in outputs + ] + + if max([len(h) for h in hops]) > 100: + return False + if any( + dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) + for idx in range(2, len(hops)) + ): + return False + + return True + + +def gold_passages_retrieved(example, pred, trace=None): + gold_titles = set(map(dspy.evaluate.normalize_text, example["gold_titles"])) + found_titles = set( + map(dspy.evaluate.normalize_text, [c.split(" | ")[0] for c in pred.context]) + ) + + return gold_titles.issubset(found_titles) + + +def hotpot(): + from dsp.utils import deduplicate + import dspy.evaluate + from dspy.datasets import HotPotQA + from dspy.evaluate.evaluate import Evaluate + from dspy.teleprompt.bootstrap import BootstrapFewShot + + print("Load the dataset.") + dataset = HotPotQA( + train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0 + ) + trainset = [x.with_inputs("question") for x in dataset.train] + devset = [x.with_inputs("question") for x in dataset.dev] + print("Done") + + class SimplifiedBaleen(FunctionalModule): + def __init__(self, passages_per_hop=3, max_hops=1): + super().__init__() + self.retrieve = dspy.Retrieve(k=passages_per_hop) + self.max_hops = max_hops + + @cot + def generate_query(self, context: list[str], question) -> str: + """Write a simple search query that will help answer a complex question.""" + pass + + @cot + def generate_answer(self, context: list[str], question) -> str: + """Answer questions with short factoid answers.""" + pass + + def forward(self, question): + context = [] + + for hop in range(self.max_hops): + query = self.generate_query(context=context, question=question) + passages = self.retrieve(query).passages + context = deduplicate(context + passages) + + answer = self.generate_answer(context=context, question=question) + return dspy.Prediction(context=context, answer=answer) + + openai.api_key = os.getenv("OPENAI_API_KEY") + rm = 
dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts") + lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000) + dspy.settings.configure(lm=lm, rm=rm, trace=[]) + + evaluate_on_hotpotqa = Evaluate( + devset=devset, num_threads=10, display_progress=True, display_table=5 + ) + + # uncompiled (i.e., zero-shot) program + uncompiled_baleen = SimplifiedBaleen() + print( + "Uncompiled Baleen retrieval score:", + evaluate_on_hotpotqa(uncompiled_baleen, metric=gold_passages_retrieved), + ) + + # compiled (i.e., few-shot) program + compiled_baleen = BootstrapFewShot( + metric=validate_context_and_answer_and_hops + ).compile( + SimplifiedBaleen(), + teacher=SimplifiedBaleen(passages_per_hop=2), + trainset=trainset, + ) + print( + "Compiled Baleen retrieval score:", + evaluate_on_hotpotqa(compiled_baleen, metric=gold_passages_retrieved), + ) + # lm.inspect_history(n=5) + + +if __name__ == "__main__": + # main() + hotpot() diff --git a/dspy/predict/__init__.py b/dspy/predict/__init__.py index f646ab69f6..8b0770150b 100644 --- a/dspy/predict/__init__.py +++ b/dspy/predict/__init__.py @@ -6,3 +6,4 @@ from .aggregation import majority from .program_of_thought import ProgramOfThought from .retry import Retry +from .knn import KNN \ No newline at end of file diff --git a/dspy/predict/aggregation.py b/dspy/predict/aggregation.py index 2212900c2d..ca4154aa2d 100644 --- a/dspy/predict/aggregation.py +++ b/dspy/predict/aggregation.py @@ -26,10 +26,11 @@ def majority(prediction_or_completions, normalize=default_normalize, field=None) except: signature = None - try: - field = field if field else signature.fields[-1].output_variable - except: - field = field if field else list(completions[0].keys())[-1] + if not field: + if signature: + field = signature.output_fields[-1] + else: + field = list(completions[0].keys())[-1] # Normalize normalize = normalize if normalize else lambda x: x @@ -51,5 +52,4 @@ def majority(prediction_or_completions, normalize=default_normalize, field=None) # if input_type == Prediction: return Prediction.from_completions([completion], signature=signature) - return Completions([completion], signature=signature) diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index fceb3b6517..7d50d8a562 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -1,6 +1,7 @@ -import dsp +import dsp, dspy +from dspy.signatures.signature import ensure_signature -from .predict import Predict +from .predict import Predict, signature_to_template # TODO: FIXME: Insert this right before the *first* output field. Also rewrite this to use the new signature system. @@ -33,24 +34,15 @@ def __init__(self, signature, rationale_type=None, activated=True, **config): self.activated = activated - signature = self.signature - *keys, last_key = signature.kwargs.keys() + signature = ensure_signature(self.signature) + *_keys, last_key = signature.output_fields.keys() - DEFAULT_RATIONALE_TYPE = dsp.Type( + rationale_type = rationale_type or dspy.OutputField( prefix="Reasoning: Let's think step by step in order to", desc="${produce the " + last_key + "}. 
We ...", ) - rationale_type = rationale_type or DEFAULT_RATIONALE_TYPE - - extended_kwargs = {key: signature.kwargs[key] for key in keys} - extended_kwargs.update( - {"rationale": rationale_type, last_key: signature.kwargs[last_key]} - ) - - self.extended_signature = dsp.Template( - signature.instructions, **extended_kwargs - ) + self.extended_signature = signature.prepend("rationale", rationale_type, type_=str) def forward(self, **kwargs): new_signature = kwargs.pop("new_signature", None) @@ -62,7 +54,8 @@ def forward(self, **kwargs): else: signature = self.signature else: - signature = dsp.Template(self.signature.instructions, **new_signature) + signature = new_signature + # template = dsp.Template(self.signature.instructions, **new_signature) return super().forward(signature=signature, **kwargs) diff --git a/dspy/predict/chain_of_thought_with_hint.py b/dspy/predict/chain_of_thought_with_hint.py index b968d0bd95..83d5b5b4b4 100644 --- a/dspy/predict/chain_of_thought_with_hint.py +++ b/dspy/predict/chain_of_thought_with_hint.py @@ -1,4 +1,4 @@ -import dsp +import dsp, dspy from .predict import Predict @@ -9,27 +9,18 @@ class ChainOfThoughtWithHint(Predict): def __init__(self, signature, rationale_type=None, activated=True, **config): super().__init__(signature, **config) - self.activated = activated - signature = self.signature - *keys, last_key = signature.kwargs.keys() - - DEFAULT_HINT_TYPE = dsp.Type(prefix="Hint:", desc="${hint}") - DEFAULT_RATIONALE_TYPE = dsp.Type(prefix="Reasoning: Let's think step by step in order to", - desc="${produce the " + last_key + "}. We ...") + *keys, last_key = signature.fields.keys() + rationale_type = rationale_type or dspy.OutputField( + prefix="Reasoning: Let's think step by step in order to", + desc="${produce the " + last_key + "}. We ...", + ) + self.extended_signature1 = self.signature.insert(-2, "rationale", rationale_type, type_=str) - rationale_type = rationale_type or DEFAULT_RATIONALE_TYPE - - extended_kwargs1 = {key: signature.kwargs[key] for key in keys} - extended_kwargs1.update({'rationale': rationale_type, last_key: signature.kwargs[last_key]}) - - extended_kwargs2 = {key: signature.kwargs[key] for key in keys} - extended_kwargs2.update({'hint': DEFAULT_HINT_TYPE, 'rationale': rationale_type, last_key: signature.kwargs[last_key]}) - - self.extended_signature1 = dsp.Template(signature.instructions, **extended_kwargs1) - self.extended_signature2 = dsp.Template(signature.instructions, **extended_kwargs2) + DEFAULT_HINT_TYPE = dspy.OutputField() + self.extended_signature2 = self.extended_signature1.insert(-2, "hint", DEFAULT_HINT_TYPE, type_=str) def forward(self, **kwargs): signature = self.signature diff --git a/dspy/predict/langchain.py b/dspy/predict/langchain.py index e3ddd37cec..4be855e8db 100644 --- a/dspy/predict/langchain.py +++ b/dspy/predict/langchain.py @@ -13,6 +13,8 @@ from langchain_core.pydantic_v1 import Extra from langchain_core.runnables import Runnable +# TODO: This class is currently hard to test, because it hardcodes gpt-4 usage: +# gpt4T = dspy.OpenAI(model='gpt-4-1106-preview', max_tokens=4000, model_type='chat') class Template2Signature(dspy.Signature): """You are a processor for prompts. I will give you a prompt template (Python f-string) for an arbitrary task for other LMs. 
diff --git a/dspy/predict/multi_chain_comparison.py b/dspy/predict/multi_chain_comparison.py index 99c2b43c5a..89fc732979 100644 --- a/dspy/predict/multi_chain_comparison.py +++ b/dspy/predict/multi_chain_comparison.py @@ -1,38 +1,55 @@ +import dspy +from dspy.signatures.signature import ensure_signature from .predict import Predict from ..primitives.program import Module import dsp + class MultiChainComparison(Module): def __init__(self, signature, M=3, temperature=0.7, **config): super().__init__() self.M = M - signature = Predict(signature).signature - *keys, last_key = signature.kwargs.keys() + signature = ensure_signature(signature) - extended_kwargs = {key: signature.kwargs[key] for key in keys} + *_, self.last_key = signature.output_fields.keys() for idx in range(M): - candidate_type = dsp.Type(prefix=f"Student Attempt #{idx+1}:", desc="${reasoning attempt}") - extended_kwargs.update({f'reasoning_attempt_{idx+1}': candidate_type}) - - rationale_type = dsp.Type(prefix="Accurate Reasoning: Thank you everyone. Let's now holistically", desc="${corrected reasoning}") - extended_kwargs.update({'rationale': rationale_type, last_key: signature.kwargs[last_key]}) + signature = signature.append( + f"reasoning_attempt_{idx+1}", + dspy.InputField( + prefix=f"Student Attempt #{idx+1}:", desc="${reasoning attempt}" + ), + ) + + signature = signature.prepend( + "rationale", + dspy.OutputField( + prefix="Accurate Reasoning: Thank you everyone. Let's now holistically", + desc="${corrected reasoning}", + ), + ) - signature = dsp.Template(signature.instructions, **extended_kwargs) self.predict = Predict(signature, temperature=temperature, **config) - self.last_key = last_key - + def forward(self, completions, **kwargs): attempts = [] for c in completions: - rationale = c.rationale.strip().split('\n')[0].strip() - answer = c[self.last_key].strip().split('\n')[0].strip() - attempts.append(f"«I'm trying to {rationale} I'm not sure but my prediction is {answer}»") + rationale = c.rationale.strip().split("\n")[0].strip() + answer = c[self.last_key].strip().split("\n")[0].strip() + attempts.append( + f"«I'm trying to {rationale} I'm not sure but my prediction is {answer}»" + ) assert len(attempts) == self.M, len(attempts) - kwargs = {**{f'reasoning_attempt_{idx+1}': attempt for idx, attempt in enumerate(attempts)}, **kwargs} + kwargs = { + **{ + f"reasoning_attempt_{idx+1}": attempt + for idx, attempt in enumerate(attempts) + }, + **kwargs, + } return self.predict(**kwargs) diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index c68ac1a5a8..3823a72a78 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -3,42 +3,16 @@ from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction -from dspy.signatures.field import InputField, OutputField -from dspy.signatures.signature import infer_prefix +from dspy.signatures.signature import ensure_signature, signature_to_template class Predict(Parameter): def __init__(self, signature, **config): self.stage = random.randbytes(8).hex() - self.signature = signature #.signature + self.signature = ensure_signature(signature) self.config = config self.reset() - # if the signature is a string - if isinstance(signature, str): - inputs, outputs = signature.split("->") - inputs, outputs = inputs.split(","), outputs.split(",") - inputs, outputs = [field.strip() for field in inputs], [field.strip() for field in outputs] - - assert all(len(field.split()) == 1 for field in (inputs + outputs)) - - inputs_ = ', 
'.join([f"`{field}`" for field in inputs]) - outputs_ = ', '.join([f"`{field}`" for field in outputs]) - - instructions = f"""Given the fields {inputs_}, produce the fields {outputs_}.""" - - inputs = {k: InputField() for k in inputs} - outputs = {k: OutputField() for k in outputs} - - for k, v in inputs.items(): - v.finalize(k, infer_prefix(k)) - - for k, v in outputs.items(): - v.finalize(k, infer_prefix(k)) - - self.signature = dsp.Template(instructions, **inputs, **outputs) - - def reset(self): self.lm = None self.traces = [] @@ -51,43 +25,47 @@ def dump_state(self): # Cache the signature instructions and the last field's name. state["signature_instructions"] = self.signature.instructions - state["signature_prefix"] = self.signature.fields[-1].name + + *_, last_key = self.signature.fields.keys() + state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra['prefix'] return state def load_state(self, state): for name, value in state.items(): setattr(self, name, value) - + # Reconstruct the signature. if "signature_instructions" in state: instructions = state["signature_instructions"] - self.signature.instructions = instructions - + self.signature = self.signature.with_instructions(instructions) + if "signature_prefix" in state: prefix = state["signature_prefix"] - self.signature.fields[-1] = self.signature.fields[-1]._replace(name=prefix) - + *_, last_key = self.signature.fields.keys() + self.signature = self.signature.with_updated_fields(last_key, prefix=prefix) + def __call__(self, **kwargs): return self.forward(**kwargs) - + def forward(self, **kwargs): # Extract the three privileged keyword arguments. - new_signature = kwargs.pop("new_signature", None) - signature = kwargs.pop("signature", self.signature) + new_signature = ensure_signature(kwargs.pop("new_signature", None)) + signature = ensure_signature(kwargs.pop("signature", self.signature)) demos = kwargs.pop("demos", self.demos) config = dict(**self.config, **kwargs.pop("config", {})) # Get the right LM to use. lm = kwargs.pop("lm", self.lm) or dsp.settings.lm + assert lm is not None, "No LM is loaded." # If temperature is 0.0 but its n > 1, set temperature to 0.7. temperature = config.get("temperature", None) - temperature = lm.kwargs['temperature'] if temperature is None else temperature + temperature = lm.kwargs["temperature"] if temperature is None else temperature num_generations = config.get("n", None) if num_generations is None: - num_generations = lm.kwargs.get('n', lm.kwargs.get('num_generations', None)) + num_generations = lm.kwargs.get("n", lm.kwargs.get("num_generations", None)) if (temperature is None or temperature <= 0.15) and num_generations > 1: config["temperature"] = 0.7 @@ -98,25 +76,35 @@ def forward(self, **kwargs): x = dsp.Example(demos=demos, **kwargs) if new_signature is not None: - signature = dsp.Template(signature.instructions, **new_signature) + signature = new_signature + + assert all(k in kwargs for k in signature.input_fields), "Not all input fields were provided." + + # Switch to legacy format for dsp.generate + template = signature_to_template(signature) if self.lm is None: - x, C = dsp.generate(signature, **config)(x, stage=self.stage) + x, C = dsp.generate(template, **config)(x, stage=self.stage) else: + # Note: query_only=True means the instructions and examples are not included. + # I'm not really sure why we'd want to do that, but it's there. 
with dsp.settings.context(lm=self.lm, query_only=True): - # print(f"using lm = {self.lm} !") - x, C = dsp.generate(signature, **config)(x, stage=self.stage) + x, C = dsp.generate(template, **config)(x, stage=self.stage) + + assert self.stage in x, "The generated (input, output) example was not stored" completions = [] for c in C: completions.append({}) - for field in signature.fields: + for field in template.fields: if field.output_variable not in kwargs.keys(): - completions[-1][field.output_variable] = getattr(c, field.output_variable) + completions[-1][field.output_variable] = getattr( + c, field.output_variable + ) pred = Prediction.from_completions(completions, signature=signature) - + if kwargs.pop("_trace", True) and dsp.settings.trace is not None: trace = dsp.settings.trace trace.append((self, {**kwargs}, pred)) @@ -125,7 +113,7 @@ def forward(self, **kwargs): def update_config(self, **kwargs): self.config = {**self.config, **kwargs} - + def get_config(self): return self.config @@ -133,7 +121,6 @@ def __repr__(self): return f"{self.__class__.__name__}({self.signature})" - # TODO: get some defaults during init from the context window? # # TODO: FIXME: Hmm, I guess expected behavior is that contexts can # affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates. diff --git a/dspy/predict/program_of_thought.py b/dspy/predict/program_of_thought.py index 65d1613b3d..516c1129cf 100644 --- a/dspy/predict/program_of_thought.py +++ b/dspy/predict/program_of_thought.py @@ -1,94 +1,159 @@ import dsp import dspy +from dspy.signatures.signature import ensure_signature from ..primitives.program import Module from ..primitives.python_interpreter import CodePrompt, PythonInterpreter import re + class ProgramOfThought(Module): def __init__(self, signature, max_iters=3): super().__init__() - self.signature = signature = dspy.Predict(signature).signature + self.signature = signature = ensure_signature(signature) self.max_iters = max_iters - self.input_fields = signature.input_fields() - self.output_fields = signature.output_fields() + self.input_fields = signature.input_fields + self.output_fields = signature.output_fields - inputs_ = ', '.join([f"`{field_name}`" for field_name in self.input_fields.keys()]) - outputs_ = ', '.join([f"`{field_name}`" for field_name in self.output_fields.keys()]) + inputs_ = ", ".join( + [f"`{field_name}`" for field_name in self.input_fields.keys()] + ) + outputs_ = ", ".join( + [f"`{field_name}`" for field_name in self.output_fields.keys()] + ) assert len(self.output_fields) == 1, "PoT only supports one output field." - + instr = [] - instr.append(f"You will be given {inputs_} and you will respond with {outputs_}.") - instr.append(f"Generating executable Python code that programmatically computes the correct {outputs_}.") - instr.append(f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.") - instr = '\n'.join(instr) - - self.code_generate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate'))) - self.code_regenerate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate'))) - self.generate_answer = dspy.ChainOfThought(dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer'))) + instr.append( + f"You will be given {inputs_} and you will respond with {outputs_}." 
+ ) + instr.append( + f"Generating executable Python code that programmatically computes the correct {outputs_}." + ) + instr.append( + f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}." + ) + instr = "\n".join(instr) + + self.code_generate = dspy.ChainOfThought( + dspy.Signature( + self._generate_signature("generate").fields, + self._generate_instruction("generate"), + ) + ) + self.code_regenerate = dspy.ChainOfThought( + dspy.Signature( + self._generate_signature("regenerate").fields, + self._generate_instruction("regenerate"), + ) + ) + self.generate_answer = dspy.ChainOfThought( + dspy.Signature( + self._generate_signature("answer").fields, + self._generate_instruction("answer"), + ) + ) def _generate_signature(self, mode): signature_dict = dict(self.input_fields) fields_for_mode = { - 'generate': { - 'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str) + "generate": { + "generated_code": dspy.OutputField( + prefix="Code:", + desc="python code that answers the question", + format=str, + ) + }, + "regenerate": { + "previous_code": dspy.InputField( + prefix="Previous Code:", + desc="previously-generated python code that errored", + format=str, + ), + "error": dspy.InputField( + prefix="Error:", + desc="error message from previously-generated python code", + ), + "generated_code": dspy.OutputField( + prefix="Code:", + desc="python code that answers the question", + format=str, + ), }, - 'regenerate': { - 'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str), - 'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"), - 'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str) + "answer": { + "final_generated_code": dspy.InputField( + prefix="Code:", + desc="python code that answers the question", + format=str, + ), + "code_output": dspy.InputField( + prefix="Code Output:", + desc="output of previously-generated python code", + ), + "answer": self.signature.fields["answer"], }, - 'answer': { - 'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str), - 'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"), - 'answer': self.signature.kwargs["answer"] - } } signature_dict.update(fields_for_mode[mode]) - return signature_dict + return dspy.Signature(signature_dict) def _generate_instruction(self, mode): - mode_inputs = ', '.join([f"`{field_name}`" for field_name in self._generate_signature(mode).keys() if isinstance(self._generate_signature(mode)[field_name], dspy.InputField)]) - mode_outputs = ', '.join([f"`{field_name}`" for field_name in self._generate_signature(mode).keys() if isinstance(self._generate_signature(mode)[field_name], dspy.OutputField)]) - if mode == 'generate': + mode_inputs = ", ".join( + [ + f"`{field_name}`" + for field_name in self._generate_signature(mode).input_fields + ] + ) + mode_outputs = ", ".join( + [ + f"`{field_name}`" + for field_name in self._generate_signature(mode).output_fields + ] + ) + if mode == "generate": instr = [ f"You will be given {mode_inputs} and you will respond with {mode_outputs}.", f"Generating executable Python code that programmatically computes the correct {mode_outputs}.", - f"After you're done with the computation, make sure the last line 
in your code evaluates to the correct value for {mode_outputs}." + f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}.", ] - elif mode == 'regenerate': + elif mode == "regenerate": instr = [ f"You are given {mode_inputs} due to an error in previous code.", - f"Your task is to correct the error and provide the new {mode_outputs}." + f"Your task is to correct the error and provide the new {mode_outputs}.", ] else: # mode == 'answer' instr = [ f"Given the final code {mode_inputs}, provide the final {mode_outputs}." ] - return '\n'.join(instr) + return "\n".join(instr) def parse_code(self, code_data): - code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0] - code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL) - code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n') + code = ( + code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0] + ) + code_match = re.search(r"```python[ \n](.*?)[ \n]```?", code, re.DOTALL) + code_block = (code_match.group(1) if code_match else code).replace("\\n", "\n") if not code_block: return code, "Error: Empty code after parsing." - if "\n" not in code_block and code_block.count('=') > 1: + if "\n" not in code_block and code_block.count("=") > 1: return code, "Error: Code format is not correct." - lines = code_block.split('\n') - last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip()) + lines = code_block.split("\n") + last_line_match = re.match(r"^(\w+)\s*=", lines[-1].strip()) if last_line_match and len(lines) > 1: - code_block += '\n' + last_line_match.group(1) + code_block += "\n" + last_line_match.group(1) else: - code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block) - code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block) + code_block = re.sub( + r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", r"\1\n", code_block + ) + code_block = re.sub( + r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", r"\1\n\2", code_block + ) return code_block, None def execute_code(self, code): if not code: - return code, None, 'Error: Empty code before execution.' + return code, None, "Error: Empty code before execution." code_prompt = CodePrompt(code, code_type="python") interpreter = PythonInterpreter(action_space={"print": print}) try: @@ -96,19 +161,26 @@ def execute_code(self, code): return code, output, None except Exception as e: return code, None, str(e) - + def forward(self, **kwargs): code_data = self.code_generate(question=kwargs["question"]) parsed_code, error = self.parse_code(code_data) + # FIXME: Don't try to execute the code if it didn't parse code, output, error = self.execute_code(parsed_code) hop = 0 while hop < self.max_iters and error: - print('Error in code execution') - code_data = self.code_regenerate(question=kwargs["question"], previous_code=code, error=error) + print("Error in code execution") + code_data = self.code_regenerate( + question=kwargs["question"], previous_code=code, error=error + ) parsed_code, error = self.parse_code(code_data) + # FIXME: Don't try to execute the code if it didn't parse + code, output, error = self.execute_code(parsed_code) hop += 1 if hop == self.max_iters: - print('Max hops reached. Error persists.') + print("Max hops reached. 
Error persists.") return None - answer_gen_result = self.generate_answer(question=kwargs["question"], final_generated_code=code, code_output=output) + answer_gen_result = self.generate_answer( + question=kwargs["question"], final_generated_code=code, code_output=output + ) return answer_gen_result diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 7dc3d1bd94..ef24c6aca0 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -1,5 +1,6 @@ import dsp import dspy +from dspy.signatures.signature import ensure_signature from ..primitives.program import Module from .predict import Predict @@ -10,35 +11,43 @@ class ReAct(Module): def __init__(self, signature, max_iters=5, num_results=3, tools=None): super().__init__() - self.signature = signature = dspy.Predict(signature).signature + self.signature = signature = ensure_signature(signature) self.max_iters = max_iters self.tools = tools or [dspy.Retrieve(k=num_results)] - self.tools = {tool.name: tool for tool in self.tools} #if isinstance(self.tools, list) else self.tools + self.tools = {tool.name: tool for tool in self.tools} - self.input_fields = {k: v for k, v in self.signature.kwargs.items() if isinstance(v, dspy.InputField)} - self.output_fields = {k: v for k, v in self.signature.kwargs.items() if isinstance(v, dspy.OutputField)} + self.input_fields = self.signature.input_fields + self.output_fields = self.signature.output_fields - inputs, outputs = signature.fields[:-1], signature.fields[-1:] + assert len(self.output_fields) == 1, "ReAct only supports one output field." - inputs_ = ', '.join([f"`{field.input_variable}`" for field in inputs]) - outputs_ = ', '.join([f"`{field.output_variable}`" for field in outputs]) + inputs_ = ", ".join([f"`{k}`" for k in self.input_fields.keys()]) + outputs_ = ", ".join([f"`{k}`" for k in self.output_fields.keys()]) - assert len(outputs) == 1, "ReAct only supports one output field." 
+ instr = [ + f"You will be given {inputs_} and you will respond with {outputs_}.\n", + "To do this, you will interleave Thought, Action, and Observation steps.\n", + "Thought can reason about the current situation, and Action can be the following types:\n", + ] - instr = [] - instr.append(f"You will be given {inputs_} and you will respond with {outputs_}.\n") - instr.append("To do this, you will interleave Thought, Action, and Observation steps.\n") - instr.append("Thought can reason about the current situation, and Action can be the following types:\n") - - self.tools['Finish'] = dspy.Example(name="Finish", input_variable=outputs_.strip('`'), desc=f"returns the final {outputs_} and finishes the task") + self.tools["Finish"] = dspy.Example( + name="Finish", + input_variable=outputs_.strip("`"), + desc=f"returns the final {outputs_} and finishes the task", + ) for idx, tool in enumerate(self.tools): tool = self.tools[tool] - instr.append(f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}") - - instr = '\n'.join(instr) - self.react = [Predict(dsp.Template(instr, **self._generate_signature(i))) for i in range(1, max_iters + 1)] + instr.append( + f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}" + ) + + instr = "\n".join(instr) + self.react = [ + Predict(dspy.Signature(self._generate_signature(i), instr)) + for i in range(1, max_iters + 1) + ] def _generate_signature(self, iters): signature_dict = {} @@ -46,25 +55,42 @@ def _generate_signature(self, iters): signature_dict[key] = val for j in range(1, iters + 1): - signature_dict[f"Thought_{j}"] = dspy.OutputField(prefix=f"Thought {j}:", desc="next steps to take based on last observation") - - tool_list = ' or '.join([f"{tool.name}[{tool.input_variable}]" for tool in self.tools.values() if tool.name != 'Finish']) - signature_dict[f"Action_{j}"] = dspy.OutputField(prefix=f"Action {j}:", desc=f"always either {tool_list} or, when done, Finish[answer]") + signature_dict[f"Thought_{j}"] = dspy.OutputField( + prefix=f"Thought {j}:", + desc="next steps to take based on last observation", + ) + + tool_list = " or ".join( + [ + f"{tool.name}[{tool.input_variable}]" + for tool in self.tools.values() + if tool.name != "Finish" + ] + ) + signature_dict[f"Action_{j}"] = dspy.OutputField( + prefix=f"Action {j}:", + desc=f"always either {tool_list} or, when done, Finish[answer]", + ) if j < iters: - signature_dict[f"Observation_{j}"] = dspy.OutputField(prefix=f"Observation {j}:", desc="observations based on action", format=dsp.passages2text) + signature_dict[f"Observation_{j}"] = dspy.OutputField( + prefix=f"Observation {j}:", + desc="observations based on action", + format=dsp.passages2text, + ) return signature_dict - + def act(self, output, hop): try: action = output[f"Action_{hop+1}"] - action_name, action_val = action.strip().split('\n')[0].split('[', 1) - action_val = action_val.rsplit(']', 1)[0] + action_name, action_val = action.strip().split("\n")[0].split("[", 1) + action_val = action_val.rsplit("]", 1)[0] - if action_name == 'Finish': return action_val + if action_name == "Finish": + return action_val - try: + try: output[f"Observation_{hop+1}"] = self.tools[action_name](action_val).passages except AttributeError: # Handle the case where 'passages' attribute is missing @@ -72,8 +98,10 @@ def act(self, output, hop): output[f"Observation_{hop+1}"] = self.tools[action_name](action_val) except Exception as e: - output[f"Observation_{hop+1}"] = "Failed to parse action. Bad formatting or incorrect action name." 
- + output[f"Observation_{hop+1}"] = ( + "Failed to parse action. Bad formatting or incorrect action name." + ) + raise e def forward(self, **kwargs): args = {key: kwargs[key] for key in self.input_fields.keys() if key in kwargs} @@ -81,9 +109,10 @@ def forward(self, **kwargs): for hop in range(self.max_iters): # with dspy.settings.context(show_guidelines=(i <= 2)): output = self.react[hop](**args) - - if action_val := self.act(output, hop): break + + if action_val := self.act(output, hop): + break args.update(output) # assumes only 1 output field for now - TODO: handling for multiple output fields - return dspy.Prediction(**{list(self.output_fields.keys())[0]: action_val or ''}) + return dspy.Prediction(**{list(self.output_fields.keys())[0]: action_val or ""}) diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py index b8f06633bf..af1d37f98b 100644 --- a/dspy/predict/retry.py +++ b/dspy/predict/retry.py @@ -9,41 +9,36 @@ class Retry(Predict): def __init__(self, module): super().__init__(module.signature) self.module = module - self.original_signature = module.signature.signature + self.original_signature = module.signature self.original_forward = module.forward self.new_signature = self._create_new_signature(self.original_signature) - def _create_new_signature(self, original_signature): - extended_signature = {} - input_fields = original_signature.input_fields() - output_fields = original_signature.output_fields() - modified_output_fields = {} - - for key, value in output_fields.items(): - modified_output_fields[f"past_{key}"] = dspy.InputField( - prefix="Past " + value.prefix, + def _create_new_signature(self, signature): + # Add "Past" input fields for each output field + for key, value in signature.output_fields.items(): + signature = signature.append(f"past_{key}", dspy.InputField( + prefix="Past " + value.json_schema_extra["prefix"], desc="past output with errors", - format=value.format, - ) - - extended_signature.update(input_fields) - extended_signature.update(modified_output_fields) + format=value.json_schema_extra.get("format"), + )) - extended_signature["feedback"] = dspy.InputField( + signature = signature.append("feedback", dspy.InputField( prefix="Instructions:", desc="Some instructions you must satisfy", format=str, - ) - extended_signature.update(output_fields) + )) - return extended_signature + return signature - def forward(self, *args, **kwargs): - for key, value in kwargs["past_outputs"].items(): + def forward(self, *, past_outputs, **kwargs): + # Convert the dict past_outputs={"answer": ...} to kwargs + # {past_answer=..., ...} + for key, value in past_outputs.items(): past_key = f"past_{key}" - if past_key in self.new_signature: + if past_key in self.new_signature.input_fields: kwargs[past_key] = value - del kwargs["past_outputs"] + # Tell the wrapped module to use the new signature. + # Note: This only works if the wrapped module is a Predict or ChainOfThought. 
kwargs["new_signature"] = self.new_signature return self.original_forward(**kwargs) diff --git a/dspy/primitives/assertions.py b/dspy/primitives/assertions.py index 5f89896e24..feb54005c3 100644 --- a/dspy/primitives/assertions.py +++ b/dspy/primitives/assertions.py @@ -238,7 +238,6 @@ def wrapper(*args, **kwargs): else: try: dsp.settings.trace.clear() - # print("backtrack", dspy.settings.backtrack_to) result = func(*args, **kwargs) break except (DSPySuggestionError, DSPyAssertionError) as e: @@ -282,13 +281,13 @@ def wrapper(*args, **kwargs): dspy.settings.backtrack_to ].append(error_msg) - output_fields = vars(error_state[0].signature.signature) + # assert isinstance(error_state[0].signature, dspy.Signature) + output_fields = error_state[0].signature.output_fields past_outputs = {} - for field_name, field_obj in output_fields.items(): - if isinstance(field_obj, dspy.OutputField): - past_outputs[field_name] = getattr( - error_state[2], field_name, None - ) + for field_name in output_fields.keys(): + past_outputs[field_name] = getattr( + error_state[2], field_name, None + ) # save latest failure trace for predictor per suggestion error_ip = error_state[1] diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index f05ec01115..11ae2795a7 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -14,6 +14,7 @@ import ast import difflib import importlib +import re import typing import inspect from typing import ( @@ -506,10 +507,11 @@ class TextPrompt(str): @property def key_words(self) -> Set[str]: - r"""Returns a set of strings representing the keywords in the prompt. - """ - from camel.utils import get_prompt_template_key_words - return get_prompt_template_key_words(self) + """Returns a set of strings representing the keywords in the prompt.""" + # Regex to find format placeholders within the string, excluding escaped braces + pattern = re.compile(r"\{([^{}]+)\}") + found = pattern.findall(self) + return set(found) def format(self, *args: Any, **kwargs: Any) -> 'TextPrompt': r"""Overrides the built-in :obj:`str.format` method to allow for diff --git a/dspy/signatures/field.py b/dspy/signatures/field.py index 848439b6d4..7822c625c0 100644 --- a/dspy/signatures/field.py +++ b/dspy/signatures/field.py @@ -1,31 +1,71 @@ -import re -import dsp +import pydantic -class Field: + +def move_kwargs(**kwargs): + # Pydantic doesn't allow arbitrary arguments to be given to fields, + # but asks that + # > any extra data you want to add to the JSON schema should be passed + # > as a dictionary to the json_schema_extra keyword argument. 
+ # See: https://docs.pydantic.dev/2.6/migration/#changes-to-pydanticfield + pydantic_kwargs = {} + json_schema_extra = {} + for k, v in kwargs.items(): + if k in ["desc", "prefix", "format", "parser", "__dspy_field_type"]: + json_schema_extra[k] = v + else: + pydantic_kwargs[k] = v + pydantic_kwargs["json_schema_extra"] = json_schema_extra + return pydantic_kwargs + + +def InputField(**kwargs): + return pydantic.Field(**move_kwargs(**kwargs, __dspy_field_type="input")) + + +def OutputField(**kwargs): + return pydantic.Field(**move_kwargs(**kwargs, __dspy_field_type="output")) + + +def new_to_old_field(field): + return ( + OldInputField + if field.json_schema_extra["__dspy_field_type"] == "input" + else OldOutputField + )( + prefix=field.json_schema_extra["prefix"], + desc=field.json_schema_extra["desc"], + format=field.json_schema_extra.get("format"), + ) + + +class OldField: """A more ergonomic datatype that infers prefix and desc if omitted.""" + def __init__(self, *, prefix=None, desc=None, input, format=None): self.prefix = prefix # This can be None initially and set later self.desc = desc self.format = format - + def finalize(self, key, inferred_prefix): """Set the prefix if it's not provided explicitly.""" if self.prefix is None: self.prefix = inferred_prefix + ":" - + if self.desc is None: - self.desc = f'${{{key}}}' - + self.desc = f"${{{key}}}" + def __repr__(self): return f"{self.__class__.__name__}(prefix={self.prefix}, desc={self.desc})" - + def __eq__(self, __value: object) -> bool: return self.__dict__ == __value.__dict__ -class InputField(Field): + +class OldInputField(OldField): def __init__(self, *, prefix=None, desc=None, format=None): super().__init__(prefix=prefix, desc=desc, input=True, format=format) -class OutputField(Field): + +class OldOutputField(OldField): def __init__(self, *, prefix=None, desc=None, format=None): super().__init__(prefix=prefix, desc=desc, input=False, format=format) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index ea31a46d4e..b73c956d78 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,176 +1,256 @@ -import re +from copy import deepcopy import dsp +from pydantic import BaseModel, Field, create_model +from typing import Type, Union, Dict, Tuple +import re -from .field import Field, InputField, OutputField -import threading +from dspy.signatures.field import InputField, OutputField, new_to_old_field -class SignatureMeta(type): - _thread_local_storage = threading.local() - class _SignatureNamespace: - def __init__(self, fields): - for key, value in fields.items(): - setattr(self, key, value) +def signature_to_template(signature): + """Convert from new to legacy format""" + return dsp.Template( + signature.instructions, + **{name: new_to_old_field(field) for name, field in signature.fields.items()}, + ) - def input_fields(self): - return {k: v for k, v in self.__dict__.items() if isinstance(v, InputField)} - def output_fields(self): - return {k: v for k, v in self.__dict__.items() if isinstance(v, OutputField)} - +def _default_instructions(cls): + inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields.keys()]) + outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields.keys()]) + return f"Given the fields {inputs_}, produce the fields {outputs_}." 
- def __new__(cls, name, bases, class_dict): - type_attributes = {} - for k, v in list(class_dict.items()): - if isinstance(v, Field): - v.finalize(k, infer_prefix(k)) - type_attributes[k] = v - del class_dict[k] +class SignatureMeta(type(BaseModel)): + def __new__(mcs, name, bases, namespace, **kwargs): + # Set `str` as the default type for all fields + raw_annotations = namespace.get("__annotations__", {}) + for name, field in namespace.items(): + if not name.startswith("__") and name not in raw_annotations: + raw_annotations[name] = str + namespace["__annotations__"] = raw_annotations - instructions = class_dict.get('__doc__') or "" + # Let Pydantic do its thing + cls = super().__new__(mcs, name, bases, namespace, **kwargs) - new_class = super().__new__(cls, name, bases, class_dict) + if cls.__doc__ is None: + cls.__doc__ = _default_instructions(cls) - # Attach the _SignatureNamespace directly to the class - setattr(new_class, 'signature', cls._SignatureNamespace(type_attributes)) + # Ensure all fields are declared with InputField or OutputField + cls._validate_fields() - # Create and attach the template directly to the class - setattr(new_class, '_template', dsp.Template(instructions=instructions, **type_attributes)) + # Ensure all fields have a prefix + for name, field in cls.model_fields.items(): + if "prefix" not in field.json_schema_extra: + field.json_schema_extra["prefix"] = infer_prefix(name) + ":" + if "desc" not in field.json_schema_extra: + field.json_schema_extra["desc"] = f"${{{name}}}" - return new_class + return cls + + def _validate_fields(cls): + for name, field in cls.model_fields.items(): + extra = field.json_schema_extra or {} + field_type = extra.get("__dspy_field_type") + if field_type not in ["input", "output"]: + raise TypeError( + f"Field '{name}' in '{cls.__name__}' must be declared with InputField or OutputField." 
+ ) @property - def kwargs(cls): - return cls.signature.fields - - def __call__(cls, *args, **kwargs): - if len(args) == 1 and isinstance(args[0], str): - instance = super(SignatureMeta, cls).__call__(*args, **kwargs) - return instance - #old - return cls._template(*args, **kwargs) - - def __getattr__(cls, attr): - # Redirect attribute access to the template object when accessed on the class directly - if attr not in cls.__dict__: - return getattr(cls._template, attr) - return super().__getattr__(attr) - -class Signature(metaclass=SignatureMeta): - def __init__(self, signature: str = "", instructions: str = ""): - self.signature = signature - self.instructions = instructions - self.fields = {} - self.parse_structure() - - def __getattr__(self, attr): - if attr not in self.__dict__: - return getattr(self.__class__, attr) - return super().__getattr__(attr) + def signature(cls) -> str: + in_args = ", ".join(cls.input_fields.keys()) + out_args = ", ".join(cls.output_fields.keys()) + return f"{in_args} -> {out_args}" @property - def kwargs(self): - return {k: v for k, v in self.fields.items()} - - def parse_structure(self): - inputs_str, outputs_str = self.signature.split("->") - for name in inputs_str.split(","): - self.add_field(name.strip(), InputField()) - for name in outputs_str.split(","): - self.add_field(name.strip(), OutputField()) - - def attach(self, **kwargs): - for key, (prefix, desc) in kwargs.items(): - field_type = self.fields.get(key) - if not field_type: - raise ValueError(f"{key} does not exist in this signature") - field_map = { - InputField: InputField(prefix=prefix, desc=desc), - OutputField: OutputField(prefix=prefix, desc=desc) - } - self.fields[key] = field_map.get(type(field_type)) - return self - - def add_field(self, field_name: str, field_type, position="append"): - if field_name in self.fields: - raise ValueError(f"{field_name} already exists in fields.") - if isinstance(field_type, (InputField, OutputField)): - field_instance = field_type - else: - raise ValueError(f"non-existent {field_type}.") - if isinstance(field_instance, InputField) and position == "append": - input_fields = self.input_fields() - if input_fields: - last_input_key = list(input_fields.keys())[-1] - index = list(self.fields.keys()).index(last_input_key) + 1 - self.fields = {**dict(list(self.fields.items())[:index]), field_name: field_instance, **dict(list(self.fields.items())[index:])} - else: - self.fields[field_name] = field_instance - elif isinstance(field_instance, OutputField) and position == "prepend": - output_fields = self.output_fields() - if output_fields: - first_output_key = list(output_fields.keys())[0] - index = list(self.fields.keys()).index(first_output_key) - self.fields = {**dict(list(self.fields.items())[:index]), field_name: field_instance, **dict(list(self.fields.items())[index:])} - else: - self.fields[field_name] = field_instance - elif position == "prepend": - self.fields = {field_name: field_instance, **self.fields} - elif position == "append": - self.fields[field_name] = field_instance + def instructions(cls) -> str: + return getattr(cls, "__doc__", "") + + def with_instructions(cls, instructions: str): + return create_model( + cls.__name__, __base__=Signature, __doc__=instructions, **cls.fields + ) + + @property + def fields(cls): + # Make sure to give input fields before output fields + return {**cls.input_fields, **cls.output_fields} + + def with_updated_fields(cls, name, **kwargs): + """Returns a new Signature type with the field, name, updated + with 
fields[name].json_schema_extra[key] = value.""" + fields_copy = deepcopy(cls.fields) + fields_copy[name].json_schema_extra = { + **fields_copy[name].json_schema_extra, + **kwargs, + } + return create_model( + cls.__name__, __base__=Signature, __doc__=cls.instructions, **fields_copy + ) + + @property + def input_fields(cls): + return cls._get_fields_with_type("input") + + @property + def output_fields(cls): + return cls._get_fields_with_type("output") + + def _get_fields_with_type(cls, field_type): + return { + k: v + for k, v in cls.model_fields.items() + if v.json_schema_extra["__dspy_field_type"] == field_type + } + + def prepend(cls, name, field, type_=None): + return cls.insert(0, name, field, type_) + + def append(cls, name, field, type_=None): + return cls.insert(-1, name, field, type_) + + def insert(cls, index: int, name: str, field, type_: Type = None): + # It's posisble to set the type as annotation=type in pydantic.Field(...) + # But this may be annoying for users, so we allow them to pass the type + if type_ is not None: + field.annotation = type_ + + input_fields = list(cls.input_fields.items()) + output_fields = list(cls.output_fields.items()) + + # Choose the list to insert into based on the field type + lst = ( + input_fields + if field.json_schema_extra["__dspy_field_type"] == "input" + else output_fields + ) + # We support negative insert indices + if index < 0: + index += len(lst) + 1 + if index < 0 or index > len(lst): + raise ValueError(f"Invalid index: {index}") + lst.insert(index, (name, field)) + + new_fields = dict(input_fields + output_fields) + new_signature = create_model( + cls.__name__ + "'", __base__=Signature, **new_fields + ) + new_signature.__doc__ = cls.instructions + return new_signature + + def _parse_signature(cls, signature: str) -> Tuple[Type, Field]: + pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$" + if not re.match(pattern, signature): + raise ValueError(f"Invalid signature format: '{signature}'") + + fields = {} + inputs_str, outputs_str = map(str.strip, signature.split("->")) + inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] + outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] + for name in inputs: + fields[name] = (str, InputField()) + for name in outputs: + fields[name] = (str, OutputField()) + + return fields + + def __call__( + cls, + signature: Union[str, Dict[str, Tuple[type, Field]]], + instructions: str = None, + ): + """ + Creates a new Signature type with the given fields and instructions. + Note: + Even though we're calling a type, we're not making an instance of the type. + In general we don't allow instances of Signature types to be made. The call + syntax is only for your convenience. + Parameters: + signature: Format: "input1, input2 -> output1, output2" + instructions: Optional prompt for the signature. + """ + + if isinstance(signature, str): + fields = cls._parse_signature(signature) else: - raise ValueError(f"invalid field addition. 
Please verify that your field name: {field_name}, field_type: {field_type}, and expected position: {position} are correct.") + fields = signature + + # Default prompt when no instructions are provided + if instructions is None: + sig = Signature(signature, "") # Simple way to parse input/output fields + instructions = _default_instructions(sig) + + signature = create_model("Signature", __base__=Signature, **fields) + signature.__doc__ = instructions + return signature - def input_fields(self): - return {k: v for k, v in self.fields.items() if isinstance(v, InputField)} + def equals(cls, other): + """Compare the JSON schema of two Pydantic models.""" + if not isinstance(other, type) or not issubclass(other, BaseModel): + return False + if cls.instructions != other.instructions: + return False + for name in cls.fields.keys() | other.fields.keys(): + if name not in other.fields or name not in cls.fields: + return False + # TODO: Should we compare the fields? + return True - def output_fields(self): - return {k: v for k, v in self.fields.items() if isinstance(v, OutputField)} + def __repr__(cls): + """ + Outputs something on the form: + Signature(question, context -> answer + question: str = InputField(desc="..."), + context: List[str] = InputField(desc="..."), + answer: int = OutputField(desc="..."), + ) + """ + field_reprs = [] + for name, field in cls.fields.items(): + field_reprs.append(f"{name} = Field({field})") + field_repr = "\n ".join(field_reprs) + return ( + f"Signature({cls.signature}\n" + f" instructions={repr(cls.instructions)}\n" + f" {field_repr}\n)" + ) - def __repr__(self): - s = [] - for name, _ in self.fields.items(): - value = getattr(self, name, None) - if value: - s.append(f"- {name} = {value}") - else: - s.append(f"- {name} = [field not attached]") - return f'{self.__class__.__name__}\n' + '\n'.join(s) - def __eq__(self, __value: object) -> bool: - return self._template == __value._template +class Signature(BaseModel, metaclass=SignatureMeta): + pass +def ensure_signature(signature): + if signature is None: + return None + if isinstance(signature, str): + return Signature(signature) + return signature + def infer_prefix(attribute_name: str) -> str: """Infers a prefix from an attribute name.""" - + # Convert camelCase to snake_case, but handle sequences of capital letters properly - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', attribute_name) - intermediate_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1) + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name) + intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) # Insert underscores around numbers to ensure spaces in the final output - with_underscores_around_numbers = re.sub('([a-zA-Z])(\d)', r'\1_\2', intermediate_name) - with_underscores_around_numbers = re.sub('(\d)([a-zA-Z])', r'\1_\2', with_underscores_around_numbers) + with_underscores_around_numbers = re.sub( + r"([a-zA-Z])(\d)", r"\1_\2", intermediate_name + ) + with_underscores_around_numbers = re.sub( + r"(\d)([a-zA-Z])", r"\1_\2", with_underscores_around_numbers + ) # Convert snake_case to 'Proper Title Case', but ensure acronyms are uppercased - words = with_underscores_around_numbers.split('_') + words = with_underscores_around_numbers.split("_") title_cased_words = [] for word in words: if word.isupper(): title_cased_words.append(word) else: title_cased_words.append(word.capitalize()) - - return ' '.join(title_cased_words) - -### Testing the function -assert infer_prefix('someAttributeName42IsCool') == 'Some Attribute Name 42 Is Cool' -assert 
infer_prefix('version2Update') == 'Version 2 Update' -assert infer_prefix('modelT45Enhanced') == 'Model T 45 Enhanced' -assert infer_prefix('someAttributeName') == 'Some Attribute Name' -assert infer_prefix('some_attribute_name') == 'Some Attribute Name' -assert infer_prefix('URLAddress') == 'URL Address' -assert infer_prefix('isHTTPSecure') == 'Is HTTP Secure' -assert infer_prefix('isHTTPSSecure123') == 'Is HTTPS Secure 123' \ No newline at end of file + + return " ".join(title_cased_words) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index b885bea0f7..c4804d63d2 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -80,7 +80,7 @@ def _prepare_predictor_mappings(self): for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()): assert name1 == name2, "Student and teacher must have the same program structure." - assert predictor1.signature == predictor2.signature, f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}" + assert predictor1.signature.equals(predictor2.signature), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}" assert id(predictor1) != id(predictor2), "Student and teacher must be different objects." name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2) diff --git a/dspy/teleprompt/finetune.py b/dspy/teleprompt/finetune.py index 82fdf530d8..a56adaecbe 100644 --- a/dspy/teleprompt/finetune.py +++ b/dspy/teleprompt/finetune.py @@ -7,6 +7,8 @@ import ujson from datasets.fingerprint import Hasher +from dspy.signatures.signature import signature_to_template + # from dspy.primitives import Example from .teleprompt import Teleprompter @@ -84,8 +86,9 @@ def compile(self, student, *, teacher=None, trainset, valset=None, demo = dict(demo) # TODO: FIXME: generalize. 
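The bootstrap, finetune, and signature-optimizer hunks around this point all migrate from positional access such as `signature.fields[-1]` to the pydantic-backed `Signature` metaclass introduced in `dspy/signatures/signature.py` above. A minimal usage sketch of that API, assuming it behaves exactly as introduced in this patch:

```python
# Illustrative sketch only; assumes the Signature API from dspy/signatures/signature.py above.
import dspy

# Calling the Signature type parses "inputs -> outputs" and returns a new *type*
# (no instance is created).
QA = dspy.Signature("question -> answer", "Answer the question briefly.")

assert QA.signature == "question -> answer"
assert list(QA.input_fields.keys()) == ["question"]
assert list(QA.output_fields.keys()) == ["answer"]

# Edits are functional: each call returns a fresh Signature type, which is why the
# teleprompters below reassign predictor.extended_signature instead of mutating fields.
QA2 = (
    QA.with_instructions("Answer with a single word.")
    .prepend("context", dspy.InputField(desc="may contain relevant facts"))
    .with_updated_fields("answer", prefix="Answer:")
)
assert QA2.signature == "context, question -> answer"

# Field metadata such as the prefix now lives in json_schema_extra.
*_, last_key = QA2.fields.keys()
assert QA2.fields[last_key].json_schema_extra["prefix"] == "Answer:"
```

The `*_, last_key = signature.fields.keys()` unpacking used in the optimizer hunks below is how the final output field of this ordered mapping is reached without a positional index.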
- completion = demo.pop(predictor.signature.fields[-1].output_variable) - prompt = predictor.signature.query(dsp.Example(demos=[], **demo)).strip() + template = signature_to_template(predictor.signature) + completion = demo.pop(template.fields[-1].output_variable) + prompt = template.query(dsp.Example(demos=[], **demo)).strip() finetune_data[name_].append(dict(prompt=prompt, completion=completion)) diff --git a/dspy/teleprompt/signature_opt.py b/dspy/teleprompt/signature_opt.py index 4c047b3daf..512372f00b 100644 --- a/dspy/teleprompt/signature_opt.py +++ b/dspy/teleprompt/signature_opt.py @@ -48,6 +48,8 @@ class GenerateInstructionGivenAttempts(dspy.Signature): class SignatureOptimizer(Teleprompter): def __init__(self, prompt_model=None, metric=None, breadth=10, depth=3, init_temperature=1.4, verbose=False, track_stats=False): + if breadth <= 1: + raise ValueError("Breadth must be greater than 1") self.metric = metric self.breadth = breadth self.depth = depth @@ -60,7 +62,9 @@ def _check_candidates_equal(self, candidate1, candidate2): for p1, p2 in zip(candidate1["program"].predictors(), candidate2["program"].predictors()): if not p1.extended_signature.instructions == p2.extended_signature.instructions: return False - if not p1.extended_signature.fields[-1] == p2.extended_signature.fields[-1]: + *_, p1_last_field = p1.extended_signature.fields.values() + *_, p2_last_field = p2.extended_signature.fields.values() + if not p1_last_field == p2_last_field: return False return True @@ -103,12 +107,13 @@ def compile(self, student, *, devset, eval_kwargs): for predictor in module.predictors(): basic_instruction = None basic_prefix = None + *_, last_key = predictor.extended_signature.fields.keys() if (hasattr(predictor, 'extended_signature')): basic_instruction = predictor.extended_signature.instructions - basic_prefix = predictor.extended_signature.fields[-1].name + basic_prefix = predictor.extended_signature.fields[last_key].json_schema_extra['prefix'] else: basic_instruction = predictor.extended_signature1.instructions - basic_prefix = predictor.extended_signature1.fields[-1].name + basic_prefix = predictor.extended_signature1.fields[last_key].json_schema_extra['prefix'] if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): instruct = dspy.Predict(BasicGenerateInstruction, n=self.breadth-1, temperature=self.init_temperature)(basic_instruction=basic_instruction) @@ -146,13 +151,19 @@ def compile(self, student, *, devset, eval_kwargs): # Set this new module with our instruction / prefix if (hasattr(p_new, 'extended_signature')): - p_new.extended_signature.instructions = instruction - p_new.extended_signature.fields[-1] = p_new.extended_signature.fields[-1]._replace(name=prefix) + *_, last_key = p_new.extended_signature.fields.keys() + p_new.extended_signature = p_new.extended_signature \ + .with_instructions(instruction) \ + .with_updated_fields(last_key, prefix=prefix) else: - p_new.extended_signature1.instructions = instruction - p_new.extended_signature1.fields[-1] = p_new.extended_signature1.fields[-1]._replace(name=prefix) - p_new.extended_signature2.instructions = instruction - p_new.extended_signature2.fields[-1] = p_new.extended_signature2.fields[-1]._replace(name=prefix) + *_, last_key = p_new.extended_signature1.fields.keys() + p_new.extended_signature1 = p_new.extended_signature1 \ + .with_instructions(instruction) \ + .with_updated_fields(last_key, prefix=prefix) + *_, last_key = p_new.extended_signature2.fields.keys() + p_new.extended_signature2 = 
p_new.extended_signature2 \ + .with_instructions(instruction) \ + .with_updated_fields(last_key, prefix=prefix) # Score the instruction / prefix if self.verbose: print(f"----------------") @@ -203,13 +214,19 @@ def compile(self, student, *, devset, eval_kwargs): # to ensure the next round of scores reflect the best possible version best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate['score']) if (hasattr(p_new, 'extended_signature')): - p_new.extended_signature.instructions = best_candidate["instruction"] - p_new.extended_signature.fields[-1] = p_new.extended_signature.fields[-1]._replace(name=best_candidate["prefix"]) + *_, last_key = p_old.extended_signature.fields.keys() + p_new.extended_signature = p_new.extended_signature \ + .with_instructions(best_candidate["instruction"]) \ + .with_updated_fields(last_key, prefix=best_candidate["prefix"]) else: - p_new.extended_signature1.instructions = best_candidate["instruction"] - p_new.extended_signature1.fields[-1] = p_new.extended_signature1.fields[-1]._replace(name=best_candidate["prefix"]) - p_new.extended_signature2.instructions = best_candidate["instruction"] - p_new.extended_signature2.fields[-1] = p_new.extended_signature2.fields[-1]._replace(name=best_candidate["prefix"]) + *_, last_key1 = p_old.extended_signature1.fields.keys() + p_new.extended_signature1 = p_new.extended_signature \ + .with_instructions(best_candidate["instruction"]) \ + .with_updated_fields(last_key1, prefix=best_candidate["prefix"]) + *_, last_key2 = p_old.extended_signature2.fields.keys() + p_new.extended_signature2 = p_new.extended_signature \ + .with_instructions(best_candidate["instruction"]) \ + .with_updated_fields(last_key2, prefix=best_candidate["prefix"]) if self.verbose: print(f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\np: {best_candidate['prefix']}") if self.verbose: print(f"Full predictor with update: ") for i,predictor in enumerate(module_clone.predictors()): diff --git a/dspy/teleprompt/signature_opt_bayesian.py b/dspy/teleprompt/signature_opt_bayesian.py index 68d7aacf0e..d5245bc931 100644 --- a/dspy/teleprompt/signature_opt_bayesian.py +++ b/dspy/teleprompt/signature_opt_bayesian.py @@ -1,5 +1,6 @@ import dsp import dspy +from dspy.signatures.signature import signature_to_template from dspy.teleprompt.teleprompt import Teleprompter from dspy.signatures import Signature from dspy.evaluate.evaluate import Evaluate @@ -114,10 +115,12 @@ def _print_full_program(self, program): if self.verbose: print(f"Predictor {i}") if (hasattr(predictor, 'extended_signature')): if self.verbose: print(f"i: {predictor.extended_signature.instructions}") - if self.verbose: print(f"p: {predictor.extended_signature.fields[-1].name}") + *_, last_field = predictor.extended_signature.fields.values() + if self.verbose: print(f"p: {last_field.json_schema_extra['prefix']}") else: if self.verbose: print(f"i: {predictor.extended_signature1.instructions}") - if self.verbose: print(f"p: {predictor.extended_signature1.fields[-1].name}") + *_, last_field = predictor.extended_signature1.fields.values() + if self.verbose: print(f"p: {last_field.json_schema_extra['prefix']}") if self.verbose: print("\n") def _print_model_history(self, model, n=1): @@ -186,8 +189,8 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo if example["augmented"]: if example_set_i not in example_set: example_set[example_set_i] = [] - fields_to_use = predictor.signature.fields - input_variable_names = 
[field.input_variable for field in fields_to_use] + fields_to_use = signature_to_template(predictor.signature).fields + input_variable_names = list(predictor.signature.input_fields.keys()) example_with_only_signature_fields = {key: value for key, value in example.items() if key in input_variable_names} example_string = self._create_example_string(fields_to_use, example_with_only_signature_fields) example_set[example_set_i].append(example_string) @@ -202,16 +205,28 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo basic_prefix = None if (hasattr(predictor, 'extended_signature')): basic_instruction = predictor.extended_signature.instructions - basic_prefix = predictor.extended_signature.fields[-1].name + *_, last_field = predictor.extended_signature.fields.values() + basic_prefix = last_field.json_schema_extra["prefix"] else: basic_instruction = predictor.extended_signature1.instructions - basic_prefix = predictor.extended_signature1.fields[-1].name + *_, last_field = predictor.extended_signature1.fields.values() + basic_prefix = last_field.json_schema_extra["prefix"] with dspy.settings.context(lm=self.prompt_model): # Data & Examples if view_data and view_examples: + if 1 not in example_sets[id(predictor)].keys(): + raise ValueError("No examples found for the given predictor") instruct = None - for i in range(1,self.n): - new_instruct = dspy.Predict(BasicGenerateInstructionWithExamplesAndDataObservations, n=1, temperature=self.init_temperature)(basic_instruction=basic_instruction, observations=self.observations, examples=example_sets[id(predictor)][i]) + for i in range(1, self.n): + new_instruct = dspy.Predict( + BasicGenerateInstructionWithExamplesAndDataObservations, + n=1, + temperature=self.init_temperature + )( + basic_instruction=basic_instruction, + observations=self.observations, + examples=example_sets[id(predictor)][i] + ) if not instruct: instruct = new_instruct else: @@ -224,7 +239,14 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo elif view_examples: instruct = None for i in range(1,self.n): # Note: skip over the first example set which is empty - new_instruct = dspy.Predict(BasicGenerateInstructionWithExamples, n=1, temperature=self.init_temperature)(basic_instruction=basic_instruction, examples=example_sets[id(predictor)][i]) + new_instruct = dspy.Predict( + BasicGenerateInstructionWithExamples, + n=1, + temperature=self.init_temperature + )( + basic_instruction=basic_instruction, + examples=example_sets[id(predictor)][i] + ) if not instruct: instruct = new_instruct else: @@ -303,8 +325,11 @@ def objective(trial): p_demo_candidates = demo_candidates[id(p_old)] # Suggest the index of the instruction candidate to use in our trial - instruction_idx = trial.suggest_categorical(f"{id(p_old)}_predictor_instruction",range(len(p_instruction_candidates))) - demos_idx = trial.suggest_categorical(f"{id(p_old)}_predictor_demos",range(len(p_demo_candidates))) + #instruction_idx = trial.suggest_categorical(f"{id(p_old)}_predictor_instruction",range(len(p_instruction_candidates))) + #demos_idx = trial.suggest_categorical(f"{id(p_old)}_predictor_demos",range(len(p_demo_candidates))) + instruction_idx = trial.suggest_int(f"{id(p_old)}_predictor_instruction",low=0, high=len(p_instruction_candidates)-1) + demos_idx = trial.suggest_int(f"{id(p_old)}_predictor_demos",low=0, high=len(p_demo_candidates)-1) + trial_logs[trial_num][f"{id(p_old)}_predictor_instruction"] = instruction_idx 
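The change just above replaces `trial.suggest_categorical(..., range(n))` with `trial.suggest_int(..., low=0, high=n-1)` when sampling candidate indices. A small standalone sketch of that Optuna pattern; the candidate list and objective below are hypothetical stand-ins, not code from this patch:

```python
# Hypothetical standalone sketch of the index-sampling pattern used above.
import optuna

candidates = ["instruction A", "instruction B", "instruction C"]

def objective(trial):
    # An integer parameter over [0, len(candidates) - 1] replaces a
    # categorical parameter over range(len(candidates)).
    idx = trial.suggest_int("candidate_idx", low=0, high=len(candidates) - 1)
    return float(len(candidates[idx]))  # stand-in for the real evaluation score

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=5)
print(study.best_params)  # e.g. {'candidate_idx': 1}
```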
trial_logs[trial_num][f"{id(p_old)}_predictor_demos"] = demos_idx @@ -314,8 +339,10 @@ def objective(trial): selected_prefix = selected_candidate.proposed_prefix_for_output_field.strip('"').strip() # Use this candidates in our program - p_new.extended_signature.instructions = selected_instruction - p_new.extended_signature.fields[-1] = p_new.extended_signature.fields[-1]._replace(name=selected_prefix) + *_, last_field = p_new.extended_signature.fields.keys() + p_new.extended_signature = p_new.extended_signature \ + .with_instructions(selected_instruction) \ + .with_updated_fields(last_field, prefix=selected_prefix) # Get the selected demos selected_demos = p_demo_candidates[demos_idx] @@ -353,8 +380,9 @@ def objective(trial): trial_num += 1 raise optuna.TrialPruned() - if self.verbose: print(f"Fully evaled score: {curr_weighted_avg_score}") - self._print_model_history(self.task_model, n=1) + if self.verbose: + print(f"Fully evaled score: {curr_weighted_avg_score}") + self._print_model_history(self.task_model, n=1) score = curr_weighted_avg_score trial_logs[trial_num]["score"] = curr_weighted_avg_score diff --git a/dspy/utils/__init__.py b/dspy/utils/__init__.py new file mode 100644 index 0000000000..c6f239df08 --- /dev/null +++ b/dspy/utils/__init__.py @@ -0,0 +1 @@ +from .dummies import * \ No newline at end of file diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py new file mode 100644 index 0000000000..a0c997145b --- /dev/null +++ b/dspy/utils/dummies.py @@ -0,0 +1,144 @@ +import random +from dsp.modules import LM +from typing import List, Union, Dict +import numpy as np +from dsp.utils.utils import dotdict +import re + + +class DummyLM(LM): + """Dummy language model for unit testing purposes.""" + + def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bool = False): + """ + Initializes the dummy language model. + Parameters: + - answers: A list of strings or a dictionary with string keys and values. + - follow_examples: If True, and the prompt contains an example exactly equal to the prompt, + the dummy model will return the next string in the list for each request. + If a list is provided, the dummy model will return the next string in the list for each request. + If a dictionary is provided, the dummy model will return the value corresponding to the key that matches the prompt. + """ + super().__init__("dummy-model") + self.provider = "dummy" + self.answers = answers + self.follow_examples = follow_examples + + def basic_request(self, prompt, n=1, **kwargs): + """Generates a dummy response based on the prompt.""" + dummy_response = {"choices": []} + for _ in range(n): + answer = None + + if self.follow_examples: + prefix = prompt.split("\n")[-1] + _instructions, _format, *examples, _output = prompt.split("\n---\n") + examples_str = "\n".join(examples) + possible_answers = re.findall(prefix + r"\s*(.*)", examples_str) + if possible_answers: + # We take the last answer, as the first one is just from + # the "Follow the following format" section. 
+ answer = possible_answers[-1] + print(f"DummyLM got found previous example for {prefix} with value {answer=}") + else: + print(f"DummyLM couldn't find previous example for {prefix=}") + + if answer is None: + if isinstance(self.answers, dict): + answer = next((v for k, v in self.answers.items() if k in prompt), None) + else: + if len(self.answers) > 0: + answer = self.answers[0] + self.answers = self.answers[1:] + + if answer is None: + answer = "No more responses" + + # Mimic the structure of a real language model response. + dummy_response["choices"].append({ + "text": answer, + "finish_reason": "simulated completion", + }) + + RED, GREEN, RESET = '\033[91m', '\033[92m', '\033[0m' + print("=== DummyLM ===") + print(prompt, end="") + print(f"{RED}{answer}{RESET}") + print("===") + + # Simulate processing and storing the request and response. + history_entry = { + "prompt": prompt, + "response": dummy_response, + "kwargs": kwargs, + "raw_kwargs": kwargs, + } + self.history.append(history_entry) + + return dummy_response + + def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): + """Retrieves dummy completions.""" + response = self.basic_request(prompt, **kwargs) + choices = response["choices"] + + # Filter choices and return text completions. + completions = [choice["text"] for choice in choices] + + return completions + + def get_convo(self, index): + """Get the prompt + anwer from the ith message""" + return self.history[index]['prompt'] \ + + " " \ + + self.history[index]['response']['choices'][0]['text'] + + +def dummy_rm(passages=()): + if not passages: + def inner(query:str, *, k:int, **kwargs): + assert False, "No passages defined" + return inner + max_length = max(map(len, passages)) + 100 + vectorizer = DummyVectorizer(max_length) + passage_vecs = vectorizer(passages) + def inner(query:str, *, k:int, **kwargs): + assert k <= len(passages) + query_vec = vectorizer([query])[0] + scores = passage_vecs @ query_vec + largest_idx = (-scores).argsort()[:k] + #return dspy.Prediction(passages=[passages[i] for i in largest_idx]) + return [dotdict(dict(long_text=passages[i])) for i in largest_idx] + return inner + + +class DummyVectorizer: + """Simple vectorizer based on n-grams""" + def __init__(self, max_length=100, n_gram=2): + self.max_length = max_length + self.n_gram = n_gram + self.P = 10**9 + 7 # A large prime number + random.seed(123) + self.coeffs = [random.randrange(1, self.P) for _ in range(n_gram)] + + def _hash(self, gram): + """Hashes a string using a polynomial hash function""" + h = 1 + for coeff, c in zip(self.coeffs, gram): + h = h * coeff + ord(c) + h %= self.P + return h % self.max_length + + def __call__(self, texts: List[str]) -> np.ndarray: + vecs = [] + for text in texts: + grams = [text[i:i+self.n_gram] for i in range(len(text) - self.n_gram + 1)] + vec = [0] * self.max_length + for gram in grams: + vec[self._hash(gram)] += 1 + vecs.append(vec) + + vecs = np.array(vecs, dtype=np.float32) + vecs -= np.mean(vecs, axis=1, keepdims=True) + vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-10 # Added epsilon to avoid division by zero + return vecs diff --git a/tests/evaluate/test_evaluate.py b/tests/evaluate/test_evaluate.py new file mode 100644 index 0000000000..3af3ae9dab --- /dev/null +++ b/tests/evaluate/test_evaluate.py @@ -0,0 +1,59 @@ +import dsp, dspy +from dspy.evaluate.evaluate import Evaluate +from dspy.evaluate.metrics import answer_exact_match +from dspy.predict import Predict +from dspy.utils.dummies import DummyLM + +def 
new_example(question, answer): + """Helper function to create a new example.""" + return dspy.Example( + question=question, + answer=answer, + ).with_inputs("question") + +def test_evaluate_initialization(): + devset = [new_example("What is 1+1?", "2")] + ev = Evaluate( + devset=devset, + metric=answer_exact_match, + display_progress=False, + ) + assert ev.devset == devset + assert ev.metric == answer_exact_match + assert ev.num_threads == len(devset) + assert ev.display_progress == False + +def test_evaluate_call(): + dspy.settings.configure(lm=DummyLM({"What is 1+1?": "2", "What is 2+2?": "4"})) + devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")] + program = Predict("question -> answer") + assert program(question="What is 1+1?").answer == "2" + ev = Evaluate( + devset=devset, + metric=answer_exact_match, + display_progress=False, + ) + score = ev(program) + assert score == 100.0 + +def test_evaluate_call_bad(): + dspy.settings.configure(lm=DummyLM({"What is 1+1?": "0", "What is 2+2?": "0"})) + devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")] + program = Predict("question -> answer") + ev = Evaluate( + devset=devset, + metric=answer_exact_match, + display_progress=False, + ) + score = ev(program) + assert score == 0.0 + +def test_evaluate_display_table(): + devset = [new_example("What is 1+1?", "2")] + ev = Evaluate( + devset=devset, + metric=answer_exact_match, + display_table=True, + ) + assert ev.display_table == True + diff --git a/tests/evaluate/test_metrics.py b/tests/evaluate/test_metrics.py new file mode 100644 index 0000000000..f04148251b --- /dev/null +++ b/tests/evaluate/test_metrics.py @@ -0,0 +1,32 @@ +# FILEPATH: /Users/ahle/repos/dspy/tests/evaluate/test_metrics.py + +import dsp, dspy +from dspy.evaluate.metrics import answer_exact_match +from dspy.predict import Predict + +def test_answer_exact_match_string(): + example = dspy.Example( + question="What is 1+1?", + answer="2", + ).with_inputs("question") + pred = Predict("question -> answer") + pred.answer = "2" + assert answer_exact_match(example, pred) + +def test_answer_exact_match_list(): + example = dspy.Example( + question="What is 1+1?", + answer=["2", "two"], + ).with_inputs("question") + pred = Predict("question -> answer") + pred.answer = "2" + assert answer_exact_match(example, pred) + +def test_answer_exact_match_no_match(): + example = dspy.Example( + question="What is 1+1?", + answer="2", + ).with_inputs("question") + pred = Predict("question -> answer") + pred.answer = "3" + assert not answer_exact_match(example, pred) \ No newline at end of file diff --git a/tests/examples/test_baleen.py b/tests/examples/test_baleen.py new file mode 100644 index 0000000000..ab14458444 --- /dev/null +++ b/tests/examples/test_baleen.py @@ -0,0 +1,136 @@ +import pytest +from dsp.utils import deduplicate +import dspy.evaluate +import dspy +from dspy.datasets import HotPotQA +from dspy.evaluate.evaluate import Evaluate +from dspy.teleprompt.bootstrap import BootstrapFewShot + + +class GenerateAnswer(dspy.Signature): + """Answer questions with short factoid answers.""" + + context = dspy.InputField(desc="may contain relevant facts") + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + + +class GenerateSearchQuery(dspy.Signature): + """Write a simple search query that will help answer a complex question.""" + + context = dspy.InputField(desc="may contain relevant facts") + question = dspy.InputField() + query = dspy.OutputField() + 
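As a brief aside before the module below: the two class-style signatures above go through the same `SignatureMeta` machinery as string-built signatures, so the introspection added in this patch applies to them as well. A sketch, assuming the classes defined just above and the metaclass behavior introduced earlier in this patch:

```python
# Sketch only; assumes GenerateAnswer / GenerateSearchQuery as defined above.
assert GenerateAnswer.signature == "context, question -> answer"
assert GenerateSearchQuery.signature == "context, question -> query"
assert GenerateAnswer.instructions == "Answer questions with short factoid answers."

# Structural comparison against a string-built equivalent.
assert GenerateAnswer.equals(
    dspy.Signature(
        "context, question -> answer",
        "Answer questions with short factoid answers.",
    )
)
```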
+ +class SimplifiedBaleen(dspy.Module): + def __init__(self, passages_per_hop=3, max_hops=2): + super().__init__() + + self.generate_query = [ + dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops) + ] + self.retrieve = dspy.Retrieve(k=passages_per_hop) + self.generate_answer = dspy.ChainOfThought(GenerateAnswer) + self.max_hops = max_hops + + def forward(self, question): + context = [] + + for hop in range(self.max_hops): + query = self.generate_query[hop](context=context, question=question).query + passages = self.retrieve(query).passages + context = deduplicate(context + passages) + + pred = self.generate_answer(context=context, question=question) + return dspy.Prediction(context=context, answer=pred.answer) + + +def load_hotpotqa(): + # Load the dataset. + dataset = HotPotQA( + train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0 + ) + # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata. + trainset = [x.with_inputs("question") for x in dataset.train] + devset = [x.with_inputs("question") for x in dataset.dev] + return trainset, devset + + +# @pytest.mark.slow_test +# TODO: Find a way to make this test run without openai +def _test_baleen(): + lm = dspy.OpenAI(model="gpt-3.5-turbo") + rm = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts") + dspy.settings.configure(lm=lm, rm=rm) + + # Ask any question you like to this simple RAG program. + my_question = "How many storeys are in the castle that David Gregory inherited?" + + # Get the prediction. This contains `pred.context` and `pred.answer`. + uncompiled_baleen = SimplifiedBaleen() # uncompiled (i.e., zero-shot) program + pred = uncompiled_baleen(my_question) + + assert pred.answer == "five" + + +def validate_context_and_answer_and_hops(example, pred, trace=None): + if not dspy.evaluate.answer_exact_match(example, pred): + return False + if not dspy.evaluate.answer_passage_match(example, pred): + return False + + hops = [example.question] + [ + outputs.query for *_, outputs in trace if "query" in outputs + ] + + if max([len(h) for h in hops]) > 100: + return False + if any( + dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) + for idx in range(2, len(hops)) + ): + return False + + return True + + +def gold_passages_retrieved(example, pred, trace=None): + gold_titles = set(map(dspy.evaluate.normalize_text, example["gold_titles"])) + found_titles = set( + map(dspy.evaluate.normalize_text, [c.split(" | ")[0] for c in pred.context]) + ) + + return gold_titles.issubset(found_titles) + + +# @pytest.mark.slow_test +# TODO: Find a way to make this test run without the slow hotpotqa dataset +def _test_compiled_baleen(): + trainset, devset = load_hotpotqa() + lm = dspy.OpenAI(model="gpt-3.5-turbo") + rm = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts") + dspy.settings.configure(lm=lm, rm=rm) + + uncompiled_baleen = SimplifiedBaleen() # uncompiled (i.e., zero-shot) program + + teleprompter = BootstrapFewShot(metric=validate_context_and_answer_and_hops) + compiled_baleen = teleprompter.compile( + SimplifiedBaleen(), + teacher=SimplifiedBaleen(passages_per_hop=2), + trainset=trainset, + ) + + evaluate_on_hotpotqa = Evaluate( + devset=devset, num_threads=1, display_progress=True, display_table=5 + ) + uncompiled_baleen_retrieval_score = evaluate_on_hotpotqa( + uncompiled_baleen, metric=gold_passages_retrieved, display=False + ) + # assert uncompiled_baleen_retrieval_score / 100 == 18 / 50 + + compiled_baleen_retrieval_score = 
evaluate_on_hotpotqa( + compiled_baleen, metric=gold_passages_retrieved + ) + # assert compiled_baleen_retrieval_score / 100 == 27 / 50 + assert uncompiled_baleen_retrieval_score < compiled_baleen_retrieval_score \ No newline at end of file diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py new file mode 100644 index 0000000000..8ed180cff4 --- /dev/null +++ b/tests/functional/test_functional.py @@ -0,0 +1,401 @@ +import datetime +import json +import textwrap +import pydantic +from pydantic import Field, BaseModel, field_validator +from typing import Annotated + +import pytest + +import dspy +from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor +from dspy.primitives.example import Example +from dspy.teleprompt.bootstrap import BootstrapFewShot +from dspy.utils.dummies import DummyLM + + +def test_simple(): + @predictor + def hard_question(topic: str) -> str: + """Think of a hard factual question about a topic.""" + + expected = "What is the speed of light?" + lm = DummyLM([expected]) + dspy.settings.configure(lm=lm) + + question = hard_question(topic="Physics") + lm.inspect_history(n=2) + + assert question == expected + + +def test_simple_type(): + class Question(pydantic.BaseModel): + value: str + + @predictor + def hard_question(topic: str) -> Question: + """Think of a hard factual question about a topic.""" + + expected = "What is the speed of light?" + lm = DummyLM([f'{{"value": "{expected}"}}']) + dspy.settings.configure(lm=lm) + + question = hard_question(topic="Physics") + + assert isinstance(question, Question) + assert question.value == expected + + +def test_simple_type_input(): + class Question(pydantic.BaseModel): + value: str + + class Answer(pydantic.BaseModel): + value: str + + @predictor + def answer(question: Question) -> Answer: + pass + + question = Question(value="What is the speed of light?") + lm = DummyLM([f'{{"value": "3e8"}}']) + dspy.settings.configure(lm=lm) + + result = answer(question=question) + + assert result == Answer(value="3e8") + + +def test_simple_class(): + class Answer(pydantic.BaseModel): + value: float + certainty: float + comments: list[str] = pydantic.Field( + description="At least two comments about the answer" + ) + + class QA(dspy.Module): + @predictor + def hard_question(self, topic: str) -> str: + """Think of a hard factual question about a topic. It should be answerable with a number.""" + + @cot + def answer(self, question: Annotated[str, "Question to answer"]) -> Answer: + pass + + def forward(self, **kwargs): + question = self.hard_question(**kwargs) + return (question, self.answer(question=question)) + + expected = Answer( + value=3e8, + certainty=0.9, + comments=["It is the speed of light", "It is a constant"], + ) + + lm = DummyLM( + [ + "What is the speed of light?", + "Some bad reasoning, 3e8 m/s.", + "3e8", # Bad answer 1 + "Some good reasoning...", + expected.model_dump_json(), # Good answer + ] + ) + dspy.settings.configure(lm=lm) + + qa = QA() + question, answer = qa(topic="Physics") + + assert question == "What is the speed of light?" + assert answer == expected + + +def test_simple_oop(): + class Question(pydantic.BaseModel): + value: str + + class MySignature(dspy.Signature): + topic: str = dspy.InputField() + output: Question = dspy.OutputField() + + # Run the signature + program = TypedPredictor(MySignature) + expected = "What is the speed of light?" 
+ lm = DummyLM( + [ + Question(value=expected).model_dump_json(), + ] + ) + dspy.settings.configure(lm=lm) + + question = program(topic="Physics").output + + assert isinstance(question, Question) + assert question.value == expected + + +def test_equivalent_signatures(): + class ClassSignature(dspy.Signature): + input: str = dspy.InputField() + output: str = dspy.OutputField() + + @predictor + def output(input: str) -> str: + pass + + function_signature = output.predictor.signature + + simple_signature = dspy.Signature("input -> output") + + assert ClassSignature.equals(function_signature) + assert ClassSignature.equals(simple_signature) + + +def test_named_params(): + class QA(FunctionalModule): + @predictor + def hard_question(self, topic: str) -> str: + """Think of a hard factual question about a topic. It should be answerable with a number.""" + + @cot + def answer(self, question: str) -> str: + pass + + qa = QA() + named_predictors = list(qa.named_predictors()) + assert len(named_predictors) == 2 + names, _ = zip(*qa.named_predictors()) + assert set(names) == {"hard_question.predictor", "answer.predictor"} + + +def test_bootstrap_effectiveness(): + class SimpleModule(FunctionalModule): + @predictor + def output(self, input: str) -> str: + pass + + def forward(self, **kwargs): + return self.output(**kwargs) + + def simple_metric(example, prediction, trace=None): + return example.output == prediction.output + + examples = [ + ex.with_inputs("input") + for ex in ( + Example(input="What is the color of the sky?", output="blue"), + Example( + input="What does the fox say?", + output="Ring-ding-ding-ding-dingeringeding!", + ), + ) + ] + trainset = [examples[0]] + valset = [examples[1]] + + # This test verifies if the bootstrapping process improves the student's predictions + student = SimpleModule() + teacher = SimpleModule() + assert student.output.predictor.signature.equals(teacher.output.predictor.signature) + + lm = DummyLM(["blue", "Ring-ding-ding-ding-dingeringeding!"], follow_examples=True) + dspy.settings.configure(lm=lm, trace=[]) + + bootstrap = BootstrapFewShot( + metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 + ) + compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset) + + lm.inspect_history(n=2) + + # Check that the compiled student has the correct demos + assert len(compiled_student.output.predictor.demos) == 1 + assert compiled_student.output.predictor.demos[0].input == trainset[0].input + assert compiled_student.output.predictor.demos[0].output == trainset[0].output + + # Test the compiled student's prediction. + # We are using a DummyLM with follow_examples=True, which means that + # even though it would normally reply with "Ring-ding-ding-ding-dingeringeding!" + # on the second output, if it seems an example that perfectly matches the + # prompt, it will use that instead. That is why we expect "blue" here. + prediction = compiled_student(input=trainset[0].input) + assert prediction == trainset[0].output + + assert lm.get_convo(-1) == textwrap.dedent( + """\ + Given the fields `input`, produce the fields `output`. + + --- + + Follow the following format. + + Input: ${input} + Output: ${output}. Respond with a single str value + + --- + + Input: What is the color of the sky? + Output: blue + + --- + + Input: What is the color of the sky? 
+ Output: blue""" + ) + + +def test_regex(): + class TravelInformation(BaseModel): + origin: str = Field(pattern=r"^[A-Z]{3}$") + destination: str = Field(pattern=r"^[A-Z]{3}$") + date: datetime.date + + @predictor + def flight_information(email: str) -> TravelInformation: + pass + + email = textwrap.dedent( + """\ + We're excited to welcome you aboard your upcoming flight from + John F. Kennedy International Airport (JFK) to Los Angeles International Airport (LAX) + on December 25, 2022. Here's everything you need to know before you take off: ... + """ + ) + lm = DummyLM( + [ + # Example with a bad origin code. + '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', + # Fixed + '{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}', + ] + ) + dspy.settings.configure(lm=lm) + + assert flight_information(email=email) == TravelInformation( + origin="JFK", destination="LAX", date=datetime.date(2022, 12, 25) + ) + + +def test_raises(): + class TravelInformation(BaseModel): + origin: str = Field(pattern=r"^[A-Z]{3}$") + destination: str = Field(pattern=r"^[A-Z]{3}$") + date: datetime.date + + @predictor + def flight_information(email: str) -> TravelInformation: + pass + + lm = DummyLM( + [ + "A list of bad inputs", + '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', + '{"origin": "JFK", "destination": "LAX", "date": "bad date"}', + ] + ) + dspy.settings.configure(lm=lm) + + with pytest.raises(ValueError): + flight_information(email="Some email") + + +def test_multi_errors(): + class TravelInformation(BaseModel): + origin: str = Field(pattern=r"^[A-Z]{3}$") + destination: str = Field(pattern=r"^[A-Z]{3}$") + date: datetime.date + + @predictor + def flight_information(email: str) -> TravelInformation: + pass + + lm = DummyLM( + [ + # First origin is wrong, then destination, then all is good + '{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}', + '{"origin": "JFK", "destination": "LA0", "date": "2022-12-25"}', + '{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}', + ] + ) + dspy.settings.configure(lm=lm) + + assert flight_information(email="Some email") == TravelInformation( + origin="JFK", destination="LAX", date=datetime.date(2022, 12, 25) + ) + assert lm.get_convo(-1) == textwrap.dedent( + """\ + Given the fields `email`, produce the fields `flight_information`. + + --- + + Follow the following format. + + Email: ${email} + + Past Error (flight_information): An error to avoid in the future + + Past Error (flight_information, 2): An error to avoid in the future + + Flight Information: ${flight_information}. 
Respond with a single JSON object using the schema {"properties": {"origin": {"pattern": "^[A-Z]{3}$", "title": "Origin", "type": "string"}, "destination": {"pattern": "^[A-Z]{3}$", "title": "Destination", "type": "string"}, "date": {"format": "date", "title": "Date", "type": "string"}}, "required": ["origin", "destination", "date"], "title": "TravelInformation", "type": "object"} + + --- + + Email: Some email + + Past Error (flight_information): 1 validation error for TravelInformation origin String should match pattern '^[A-Z]{3}$' [type=string_pattern_mismatch, input_value='JF0', input_type=str] For further information visit https://errors.pydantic.dev/2.5/v/string_pattern_mismatch + + Past Error (flight_information, 2): 1 validation error for TravelInformation destination String should match pattern '^[A-Z]{3}$' [type=string_pattern_mismatch, input_value='LA0', input_type=str] For further information visit https://errors.pydantic.dev/2.5/v/string_pattern_mismatch + + Flight Information: {"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}""" + ) + + +def test_field_validator(): + class UserDetails(BaseModel): + name: str + age: int + + @field_validator("name") + @classmethod + def validate_name(cls, v): + if v.upper() != v: + raise ValueError("Name must be in uppercase.") + return v + + @predictor + def get_user_details() -> UserDetails: + pass + + # Keep making the mistake (lower case name) until we run + # out of retries. + lm = DummyLM( + [ + '{"name": "lower case name", "age": 25}', + ] + * 10 + ) + dspy.settings.configure(lm=lm) + + with pytest.raises(ValueError): + get_user_details() + + assert lm.get_convo(-1) == textwrap.dedent( + """\ + Given the fields , produce the fields `get_user_details`. + + --- + + Follow the following format. + + Past Error (get_user_details): An error to avoid in the future + Past Error (get_user_details, 2): An error to avoid in the future + Get User Details: ${get_user_details}. Respond with a single JSON object using the schema {"properties": {"name": {"title": "Name", "type": "string"}, "age": {"title": "Age", "type": "integer"}}, "required": ["name", "age"], "title": "UserDetails", "type": "object"} + + --- + + Past Error (get_user_details): 1 validation error for UserDetails name Value error, Name must be in uppercase. [type=value_error, input_value='lower case name', input_type=str] For further information visit https://errors.pydantic.dev/2.5/v/value_error + Past Error (get_user_details, 2): 1 validation error for UserDetails name Value error, Name must be in uppercase. 
[type=value_error, input_value='lower case name', input_type=str] For further information visit https://errors.pydantic.dev/2.5/v/value_error + Get User Details: {"name": "lower case name", "age": 25}""" + ) diff --git a/tests/predict/test_aggregation.py b/tests/predict/test_aggregation.py new file mode 100644 index 0000000000..2c5f705fe6 --- /dev/null +++ b/tests/predict/test_aggregation.py @@ -0,0 +1,47 @@ +from dspy.predict.aggregation import majority +from dspy.primitives.prediction import Prediction, Completions +from dsp.utils import normalize_text + + +def test_majority_with_prediction(): + prediction = Prediction.from_completions( + [{"answer": "2"}, {"answer": "2"}, {"answer": "3"}] + ) + result = majority(prediction) + assert result.completions[0]["answer"] == "2" + + +def test_majority_with_completions(): + completions = Completions([{"answer": "2"}, {"answer": "2"}, {"answer": "3"}]) + result = majority(completions) + assert result.completions[0]["answer"] == "2" + + +def test_majority_with_list(): + completions = [{"answer": "2"}, {"answer": "2"}, {"answer": "3"}] + result = majority(completions) + assert result.completions[0]["answer"] == "2" + + +def test_majority_with_normalize(): + completions = [{"answer": "2"}, {"answer": " 2"}, {"answer": "3"}] + result = majority(completions, normalize=normalize_text) + assert result.completions[0]["answer"] == "2" + + +def test_majority_with_field(): + completions = [ + {"answer": "2", "other": "1"}, + {"answer": "2", "other": "1"}, + {"answer": "3", "other": "2"}, + ] + result = majority(completions, field="other") + assert result.completions[0]["other"] == "1" + + +def test_majority_with_no_majority(): + completions = [{"answer": "2"}, {"answer": "3"}, {"answer": "4"}] + result = majority(completions) + assert ( + result.completions[0]["answer"] == "2" + ) # The first completion is returned in case of a tie diff --git a/tests/predict/test_chain_of_thought.py b/tests/predict/test_chain_of_thought.py new file mode 100644 index 0000000000..c1d08e729c --- /dev/null +++ b/tests/predict/test_chain_of_thought.py @@ -0,0 +1,35 @@ +import textwrap +import dspy +from dspy import ChainOfThought +from dspy.utils import DummyLM + + +def test_initialization_with_string_signature(): + lm = DummyLM(["find the number after 1", "2"]) + dspy.settings.configure(lm=lm) + predict = ChainOfThought("question -> answer") + assert list(predict.extended_signature.output_fields.keys()) == [ + "rationale", + "answer", + ] + assert predict(question="What is 1+1?").answer == "2" + + print(lm.get_convo(-1)) + assert lm.get_convo(-1) == textwrap.dedent( + """\ + Given the fields `question`, produce the fields `answer`. + + --- + + Follow the following format. + + Question: ${question} + Reasoning: Let's think step by step in order to ${produce the answer}. We ... + Answer: ${answer} + + --- + + Question: What is 1+1? 
+ Reasoning: Let's think step by step in order to find the number after 1 + Answer: 2""" + ) diff --git a/tests/predict/test_chain_of_thought_with_hint.py b/tests/predict/test_chain_of_thought_with_hint.py new file mode 100644 index 0000000000..b5e62425dc --- /dev/null +++ b/tests/predict/test_chain_of_thought_with_hint.py @@ -0,0 +1,42 @@ +import dspy +from dspy import ChainOfThoughtWithHint +from dspy.utils import DummyLM + + +def test_cot_with_no_hint(): + lm = DummyLM(["find the number after 1", "2"]) + dspy.settings.configure(lm=lm) + predict = ChainOfThoughtWithHint("question -> answer") + assert list(predict.extended_signature2.output_fields.keys()) == [ + "rationale", + "hint", + "answer", + ] + assert predict(question="What is 1+1?").answer == "2" + + final_convo = lm.get_convo(-1) + assert final_convo.endswith( + "Question: What is 1+1?\n" + "Reasoning: Let's think step by step in order to find the number after 1\n" + "Answer: 2" + ) + + +def test_cot_with_hint(): + lm = DummyLM(["find the number after 1", "2"]) + dspy.settings.configure(lm=lm) + predict = ChainOfThoughtWithHint("question -> answer") + assert list(predict.extended_signature2.output_fields.keys()) == [ + "rationale", + "hint", + "answer", + ] + assert predict(question="What is 1+1?", hint="think small").answer == "2" + + final_convo = lm.get_convo(-1) + assert final_convo.endswith( + "Question: What is 1+1?\n\n" + "Reasoning: Let's think step by step in order to find the number after 1\n\n" + "Hint: think small\n\n" + "Answer: 2" + ) diff --git a/tests/predict/test_knn.py b/tests/predict/test_knn.py new file mode 100644 index 0000000000..62cf96682b --- /dev/null +++ b/tests/predict/test_knn.py @@ -0,0 +1,55 @@ +import pytest +import numpy as np +import dsp, dspy +from dspy.utils import DummyVectorizer +from dspy.predict import KNN + + +def mock_example(question: str, answer: str) -> dsp.Example: + """Creates a mock DSP example with specified question and answer.""" + return dspy.Example(question=question, answer=answer).with_inputs("question") + + +@pytest.fixture +def setup_knn(): + """Sets up a KNN instance with a mocked vectorizer for testing.""" + dsp.SentenceTransformersVectorizer = DummyVectorizer + trainset = [ + mock_example("What is the capital of France?", "Paris"), + mock_example("What is the largest ocean?", "Pacific"), + mock_example("What is 2+2?", "4"), + ] + knn = KNN(k=2, trainset=trainset) + return knn + + +def test_knn_initialization(setup_knn): + """Tests the KNN initialization and checks if the trainset vectors are correctly created.""" + knn = setup_knn + assert knn.k == 2, "Incorrect k value" + assert len(knn.trainset_vectors) == 3, "Incorrect size of trainset vectors" + assert isinstance( + knn.trainset_vectors, np.ndarray + ), "Trainset vectors should be a NumPy array" + + +def test_knn_query(setup_knn): + """Tests the KNN query functionality for retrieving the nearest neighbors.""" + knn = setup_knn + query = {"question": "What is 3+3?"} # A query close to "What is 2+2?" + nearest_samples = knn(**query) + assert len(nearest_samples) == 2, "Incorrect number of nearest samples returned" + assert nearest_samples[0].answer == "4", "Incorrect nearest sample returned" + + +def test_knn_query_specificity(setup_knn): + """Tests the KNN query functionality for specificity of returned examples.""" + knn = setup_knn + query = { + "question": "What is the capital of Germany?" + } # A query close to "What is the capital of France?" 
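These KNN tests rely on the `DummyVectorizer` added in `dspy/utils/dummies.py` earlier in this diff; the query built above continues right after this aside. A standalone sketch of what that vectorizer computes, assuming the implementation as shown there:

```python
# Sketch only; assumes DummyVectorizer as added in dspy/utils/dummies.py.
import numpy as np
from dspy.utils import DummyVectorizer

# Character bigrams are hashed (polynomial rolling hash) into a fixed-size count
# vector; each row is then mean-centered and L2-normalized, so a plain dot
# product serves as the similarity score.
vectorizer = DummyVectorizer(max_length=50, n_gram=2)
passages = ["What is 2+2?", "What is the capital of France?"]
passage_vecs = vectorizer(passages)           # shape (2, 50)
query_vec = vectorizer(["What is 3+3?"])[0]

scores = passage_vecs @ query_vec
print(passages[int(np.argmax(scores))])       # expected: "What is 2+2?"
```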
+ nearest_samples = knn(**query) + assert len(nearest_samples) == 2, "Incorrect number of nearest samples returned" + assert "Paris" in [ + sample.answer for sample in nearest_samples + ], "Expected Paris to be a nearest sample answer" diff --git a/tests/predict/test_multi_chain_comparison.py b/tests/predict/test_multi_chain_comparison.py new file mode 100644 index 0000000000..8c936a2d80 --- /dev/null +++ b/tests/predict/test_multi_chain_comparison.py @@ -0,0 +1,38 @@ +import dspy +from dspy.utils.dummies import DummyLM + + +def test_basic_example(): + class BasicQA(dspy.Signature): + """Answer questions with short factoid answers.""" + + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + + # Example completions generated by a model for reference + completions = [ + dspy.Prediction( + rationale="I recall that during clear days, the sky often appears this color.", + answer="blue", + ), + dspy.Prediction( + rationale="Based on common knowledge, I believe the sky is typically seen as this color.", + answer="green", + ), + dspy.Prediction( + rationale="From images and depictions in media, the sky is frequently represented with this hue.", + answer="blue", + ), + ] + + # Pass signature to MultiChainComparison module + compare_answers = dspy.MultiChainComparison(BasicQA) + + # Call the MultiChainComparison on the completions + question = "What is the color of the sky?" + lm = DummyLM(["my rationale", "blue"]) + dspy.settings.configure(lm=lm) + final_pred = compare_answers(completions, question=question) + + assert final_pred.rationale == "my rationale" + assert final_pred.answer == "blue" diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py new file mode 100644 index 0000000000..e44b3a135c --- /dev/null +++ b/tests/predict/test_predict.py @@ -0,0 +1,91 @@ +import dspy +from dspy import Predict, Signature +from dspy.utils.dummies import DummyLM + + +def test_initialization_with_string_signature(): + signature_string = "input1, input2 -> output" + predict = Predict(signature_string) + expected_instruction = ( + "Given the fields `input1`, `input2`, produce the fields `output`." 
+ ) + assert predict.signature.instructions == expected_instruction + assert predict.signature.instructions == Signature(signature_string).instructions + + +def test_reset_method(): + predict_instance = Predict("input -> output") + predict_instance.lm = "modified" + predict_instance.traces = ["trace"] + predict_instance.train = ["train"] + predict_instance.demos = ["demo"] + predict_instance.reset() + assert predict_instance.lm is None + assert predict_instance.traces == [] + assert predict_instance.train == [] + assert predict_instance.demos == [] + + +def test_dump_and_load_state(): + predict_instance = Predict("input -> output") + predict_instance.lm = "lm_state" + dumped_state = predict_instance.dump_state() + new_instance = Predict("input -> output") + new_instance.load_state(dumped_state) + assert new_instance.lm == "lm_state" + + +def test_call_method(): + predict_instance = Predict("input -> output") + lm = DummyLM(["test output"]) + dspy.settings.configure(lm=lm) + result = predict_instance(input="test input") + assert result.output == "test output" + assert lm.get_convo(-1) == ( + "Given the fields `input`, produce the fields `output`.\n" + "\n---\n\n" + "Follow the following format.\n\n" + "Input: ${input}\n" + "Output: ${output}\n" + "\n---\n\n" + "Input: test input\n" + "Output: test output" + ) + + +def test_dump_load_state(): + predict_instance = Predict(Signature("input -> output", "original instructions")) + dumped_state = predict_instance.dump_state() + new_instance = Predict(Signature("input -> output", "new instructions")) + new_instance.load_state(dumped_state) + assert new_instance.signature.instructions == "original instructions" + + +def test_forward_method(): + program = Predict("question -> answer") + dspy.settings.configure(lm=DummyLM([])) + result = program(question="What is 1+1?").answer + assert result == "No more responses" + + +def test_forward_method2(): + program = Predict("question -> answer1, answer2") + dspy.settings.configure(lm=DummyLM(["my first answer", "my second answer"])) + result = program(question="What is 1+1?") + assert result.answer1 == "my first answer" + assert result.answer2 == "my second answer" + + +def test_config_management(): + predict_instance = Predict("input -> output") + predict_instance.update_config(new_key="value") + config = predict_instance.get_config() + assert "new_key" in config and config["new_key"] == "value" + + +def test_multi_output(): + program = Predict("question -> answer", n=2) + dspy.settings.configure(lm=DummyLM(["my first answer", "my second answer"])) + results = program(question="What is 1+1?") + assert results.completions.answer[0] == "my first answer" + assert results.completions.answer[1] == "my second answer" diff --git a/tests/predict/test_program_of_thought.py b/tests/predict/test_program_of_thought.py new file mode 100644 index 0000000000..2aa153a1d6 --- /dev/null +++ b/tests/predict/test_program_of_thought.py @@ -0,0 +1,121 @@ +from dspy import Signature, ProgramOfThought +import dspy +from dspy.utils import DummyLM +import textwrap + +class BasicQA(Signature): + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + +def test_pot_code_generation(): + pot = ProgramOfThought(BasicQA) + lm = DummyLM([ + "Reason_A", + "```python\nresult = 1+1\n```", + "Reason_B", + "2", + ]) + dspy.settings.configure(lm=lm) + res = pot(question="What is 1+1?") + assert res.answer == "2" + assert lm.get_convo(index=-1) == textwrap.dedent("""\ + Given the final code `question`, 
`final_generated_code`, `code_output`, provide the final `answer`. + + --- + + Follow the following format. + + Question: ${question} + + Code: python code that answers the question + + Code Output: output of previously-generated python code + + Reasoning: Let's think step by step in order to ${produce the answer}. We ... + + Answer: often between 1 and 5 words + + --- + + Question: What is 1+1? + + Code: result = 1+1 + + Code Output: 2 + + Reasoning: Let's think step by step in order to Reason_B + + Answer: 2""") + +def test_pot_code_generation_with_error(): + pot = ProgramOfThought(BasicQA) + lm = DummyLM([ + "Reason_A", + "```python\nresult = 1+0/0\n```", + "Reason_B", # Error: division by zero + "```python\nresult = 1+1\n```", + "Reason_C", + "2", + ]) + dspy.settings.configure(lm=lm) + res = pot(question="What is 1+1?") + assert res.answer == "2" + + # The first code example failed + assert lm.get_convo(index=2) == textwrap.dedent("""\ + You are given `question`, `previous_code`, `error` due to an error in previous code. + Your task is to correct the error and provide the new `generated_code`. + + --- + + Follow the following format. + + Question: ${question} + + Previous Code: previously-generated python code that errored + + Error: error message from previously-generated python code + + Reasoning: Let's think step by step in order to ${produce the generated_code}. We ... + + Code: python code that answers the question + + --- + + Question: What is 1+1? + + Previous Code: result = 1+0/0 + + Error: division by zero + + Reasoning: Let's think step by step in order to Reason_B""") + + # The second code example succeeded + assert lm.get_convo(-1) == textwrap.dedent("""\ + Given the final code `question`, `final_generated_code`, `code_output`, provide the final `answer`. + + --- + + Follow the following format. + + Question: ${question} + + Code: python code that answers the question + + Code Output: output of previously-generated python code + + Reasoning: Let's think step by step in order to ${produce the answer}. We ... + + Answer: often between 1 and 5 words + + --- + + Question: What is 1+1? + + Code: result = 1+1 + + Code Output: 2 + + Reasoning: Let's think step by step in order to Reason_C + + Answer: 2""") diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py new file mode 100644 index 0000000000..f28e905e70 --- /dev/null +++ b/tests/predict/test_react.py @@ -0,0 +1,86 @@ +import dspy +from dspy.utils.dummies import dummy_rm + + +def test_example_no_tools(): + # Createa a simple dataset which the model will use with the Retrieve tool. + lm = dspy.utils.DummyLM( + [ + "Initial thoughts", # Thought_1 + "Finish[blue]", # Action_1 + ] + ) + dspy.settings.configure(lm=lm, rm=dummy_rm()) + + program = dspy.ReAct("question -> answer") + + # Check default tools + assert isinstance(program.tools["Finish"], dspy.Example) + + # Call the ReAct module on a particular input + question = "What is the color of the sky?" + result = program(question=question) + assert result.answer == "blue" + + # For debugging + print("---") + for row in lm.history: + print(row["prompt"]) + print("Response:", row["response"]["choices"][0]["text"]) + print("---") + + assert lm.get_convo(-1).endswith( + "Question: What is the color of the sky?\n" + "Thought 1: Initial thoughts\n" + "Action 1: Finish[blue]" + ) + + +def test_example_search(): + # Createa a simple dataset which the model will use with the Retrieve tool. 
+ lm = dspy.utils.DummyLM( + [ + "Initial thoughts", # Thought_1 + "Search[the color of the sky]", # Thought_1 + "More thoughts", # Thought_2 + "Finish[blue]", # Action_2 + ] + ) + rm = dummy_rm( + [ + "We all know the color of the sky is blue.", + "Somethng about the sky colors", + "This sentence is completely irellevant to answer the question.", + "Let's add some more sentences to act as summy passages.", + "Let's add some more sentences to act as summy passages.", + "Let's add some more sentences to act as summy passages.", + ] + ) + dspy.settings.configure(lm=lm, rm=rm) + + program = dspy.ReAct("question -> answer") + + # Check default tools + assert len(program.tools) == 2 + assert isinstance(program.tools["Search"], dspy.Retrieve) + assert isinstance(program.tools["Finish"], dspy.Example) + + # Call the ReAct module on a particular input + question = "What is the color of the sky?" + result = program(question=question) + assert result.answer == "blue" + + # For debugging + print(lm.get_convo(-1)) + + assert lm.get_convo(-1).endswith( + "Question: What is the color of the sky?\n\n" + "Thought 1: Initial thoughts\n\n" + "Action 1: Search[the color of the sky]\n\n" + "Observation 1:\n" + "[1] «We all know the color of the sky is blue.»\n" + "[2] «Somethng about the sky colors»\n" + "[3] «This sentence is completely irellevant to answer the question.»\n\n" + "Thought 2: More thoughts\n\n" + "Action 2: Finish[blue]" + ) diff --git a/tests/predict/test_retry.py b/tests/predict/test_retry.py new file mode 100644 index 0000000000..a125dde296 --- /dev/null +++ b/tests/predict/test_retry.py @@ -0,0 +1,66 @@ +import functools +import dspy +from dspy.utils import DummyLM +from dspy.primitives.assertions import assert_transform_module, backtrack_handler + + +def test_retry_simple(): + predict = dspy.Predict("question -> answer") + retry_module = dspy.Retry(predict) + + # Test Retry has created the correct new signature + for field in predict.signature.output_fields: + assert f"past_{field}" in retry_module.new_signature.input_fields + assert "feedback" in retry_module.new_signature.input_fields + + lm = DummyLM(["blue"]) + dspy.settings.configure(lm=lm) + result = retry_module.forward( + question="What color is the sky?", + past_outputs={"answer": "red"}, + feedback="Try harder", + ) + assert result.answer == "blue" + + print(lm.get_convo(-1)) + assert lm.get_convo(-1).endswith( + "Question: What color is the sky?\n\n" + "Past Answer: red\n\n" + "Instructions: Try harder\n\n" + "Answer: blue" + ) + + +def test_retry_forward_with_feedback(): + # First we make a mistake, then we fix it + lm = DummyLM(["red", "blue"]) + dspy.settings.configure(lm=lm, trace=[]) + + class SimpleModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("question -> answer") + + def forward(self, **kwargs): + result = self.predictor(**kwargs) + print(f"SimpleModule got {result.answer=}") + dspy.Suggest(result.answer == "blue", "Please think harder") + return result + + program = SimpleModule() + program = assert_transform_module( + program.map_named_predictors(dspy.Retry), + functools.partial(backtrack_handler, max_backtracks=1), + ) + + result = program(question="What color is the sky?") + + assert result.answer == "blue" + + print(lm.get_convo(-1)) + assert lm.get_convo(-1).endswith( + "Question: What color is the sky?\n\n" + "Past Answer: red\n\n" + "Instructions: Please think harder\n\n" + "Answer: blue" + ) diff --git a/tests/primitives/test_example.py 
b/tests/primitives/test_example.py new file mode 100644 index 0000000000..2f27996a24 --- /dev/null +++ b/tests/primitives/test_example.py @@ -0,0 +1,108 @@ +import pytest +from dspy import Example + + +def test_example_initialization(): + example = Example(a=1, b=2) + assert example.a == 1 + assert example.b == 2 + + +def test_example_initialization_from_base(): + base = Example(a=1, b=2) + example = Example(base=base, c=3) + assert example.a == 1 + assert example.b == 2 + assert example.c == 3 + + +def test_example_initialization_from_dict(): + base_dict = {"a": 1, "b": 2} + example = Example(base=base_dict, c=3) + assert example.a == 1 + assert example.b == 2 + assert example.c == 3 + + +def test_example_set_get_item(): + example = Example() + example["a"] = 1 + assert example["a"] == 1 + + +def test_example_attribute_access(): + example = Example(a=1) + assert example.a == 1 + example.a = 2 + assert example.a == 2 + + +def test_example_deletion(): + example = Example(a=1, b=2) + del example["a"] + with pytest.raises(AttributeError): + _ = example.a + + +def test_example_len(): + example = Example(a=1, b=2, dspy_hidden=3) + assert len(example) == 2 + + +def test_example_repr_str(): + example = Example(a=1) + assert repr(example) == "Example({'a': 1}) (input_keys=None)" + assert str(example) == "Example({'a': 1}) (input_keys=None)" + + +def test_example_eq(): + example1 = Example(a=1, b=2) + example2 = Example(a=1, b=2) + assert example1 == example2 + + +def test_example_hash(): + example1 = Example(a=1, b=2) + example2 = Example(a=1, b=2) + assert hash(example1) == hash(example2) + + +def test_example_keys_values_items(): + example = Example(a=1, b=2, dspy_hidden=3) + assert set(example.keys()) == {"a", "b"} + assert 1 in example.values() + assert ("b", 2) in example.items() + + +def test_example_get(): + example = Example(a=1, b=2) + assert example.get("a") == 1 + assert example.get("c", "default") == "default" + + +def test_example_with_inputs(): + example = Example(a=1, b=2).with_inputs("a") + assert example._input_keys == {"a"} + + +def test_example_inputs_labels(): + example = Example(a=1, b=2).with_inputs("a") + inputs = example.inputs() + assert inputs.toDict() == {"a": 1} + labels = example.labels() + assert labels.toDict() == {"b": 2} + + +def test_example_copy_without(): + example = Example(a=1, b=2) + copied = example.copy(c=3) + assert copied.a == 1 + assert copied.c == 3 + without_a = copied.without("a") + with pytest.raises(AttributeError): + _ = without_a.a + + +def test_example_to_dict(): + example = Example(a=1, b=2) + assert example.toDict() == {"a": 1, "b": 2} diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py new file mode 100644 index 0000000000..b1d7c89725 --- /dev/null +++ b/tests/primitives/test_program.py @@ -0,0 +1,66 @@ +import dspy +from dspy.primitives.program import ( + Module, + set_attribute_by_name, +) # Adjust the import based on your file structure +from dspy.utils import DummyLM + + +class HopModule(dspy.Module): + def __init__(self): + super().__init__() + self.predict1 = dspy.Predict("question -> query") + self.predict2 = dspy.Predict("query -> answer") + + def forward(self, question): + query = self.predict1(question=question).query + return self.predict2(query=query) + + +def test_module_initialization(): + module = Module() + assert ( + module._compiled is False + ), "Module _compiled attribute should be False upon initialization" + + +def test_named_predictors(): + module = HopModule() + named_preds = 
module.named_predictors() + assert len(named_preds) == 2, "Should identify correct number of Predict instances" + names, preds = zip(*named_preds) + assert ( + "predict1" in names and "predict2" in names + ), "Named predictors should include 'predict1' and 'predict2'" + + +def test_predictors(): + module = HopModule() + preds = module.predictors() + assert len(preds) == 2, "Should return correct number of Predict instances" + assert all( + isinstance(p, dspy.Predict) for p in preds + ), "All returned items should be instances of PredictMock" + + +def test_forward(): + program = HopModule() + dspy.settings.configure( + lm=DummyLM({"What is 1+1?": "let me check", "let me check": "2"}) + ) + result = program(question="What is 1+1?").answer + assert result == "2" + + +def test_nested_named_predictors(): + class Hop2Module(dspy.Module): + def __init__(self): + super().__init__() + self.hop = HopModule() + + module = Hop2Module() + named_preds = module.named_predictors() + assert len(named_preds) == 2 + names, _preds = zip(*named_preds) + assert "hop.predict1" in names + assert "hop.predict2" in names diff --git a/tests/primitives/test_python_interpreter.py b/tests/primitives/test_python_interpreter.py new file mode 100644 index 0000000000..14b15d5572 --- /dev/null +++ b/tests/primitives/test_python_interpreter.py @@ -0,0 +1,44 @@ +import pytest +from dspy.primitives.python_interpreter import PythonInterpreter, TextPrompt, CodePrompt + +def test_execute_simple_code(): + interpreter = PythonInterpreter(action_space={'print': print}) + code = "print('Hello, World!')" + result = interpreter.execute(code) + assert result is None, "Simple print statement should return None" + +def test_action_space_limitation(): + def func(string): + pass + interpreter = PythonInterpreter(action_space={}) + code = "func('This should not execute')" + with pytest.raises(Exception): + interpreter.execute(code) + +def test_import_whitelist(): + interpreter = PythonInterpreter(action_space={}, import_white_list=['math']) + code = "import math\nresult = math.sqrt(4)" + result = interpreter.execute(code) + assert result == 2, "Should be able to import and use math.sqrt" + +def test_fuzzy_variable_matching(): + interpreter = PythonInterpreter(action_space={}) + code = "result = number + 1" + result = interpreter.execute(code, fuzz_state={'number': 4}) + assert result == 5, "Fuzzy variable matching should work" + +def test_text_prompt_keyword_extraction(): + prompt = TextPrompt("Hello {name}, how are you?") + assert 'name' in prompt.key_words, "Keyword 'name' should be extracted" + +def test_text_prompt_formatting(): + prompt = TextPrompt("Hello {name}, how are you?") + formatted = prompt.format(name="Alice") + assert formatted == "Hello Alice, how are you?", "Should format with provided value" + +def test_code_prompt_execution(): + action_space = {'len': len} + interpreter = PythonInterpreter(action_space=action_space) + code_prompt = CodePrompt("result = len('hello')") + result, _ = code_prompt.execute(interpreter) + assert result == 5, "Code execution should return the length of 'hello'" diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py new file mode 100644 index 0000000000..b093258540 --- /dev/null +++ b/tests/signatures/test_signature.py @@ -0,0 +1,166 @@ +import pytest +import pydantic +from dspy import Signature, infer_prefix, InputField, OutputField +from typing import List + + +def test_field_types_and_custom_attributes(): + class TestSignature(Signature): + """Instructions""" + + 
input1: str = InputField()
+        input2: int = InputField()
+        output1: List[str] = OutputField()
+        output2 = OutputField()
+
+    assert TestSignature.instructions == "Instructions"
+    assert TestSignature.input_fields["input1"].annotation == str
+    assert TestSignature.input_fields["input2"].annotation == int
+    assert TestSignature.output_fields["output1"].annotation == List[str]
+    assert TestSignature.output_fields["output2"].annotation == str
+
+
+def test_no_input_output():
+    with pytest.raises(TypeError):
+
+        class TestSignature(Signature):
+            input1: str
+
+
+def test_no_input_output2():
+    with pytest.raises(TypeError):
+
+        class TestSignature(Signature):
+            input1: str = pydantic.Field()
+
+
+def test_all_fields_have_prefix():
+    class TestSignature(Signature):
+        input = InputField(prefix="Modified:")
+        output = OutputField()
+
+    assert (
+        TestSignature.input_fields["input"].json_schema_extra["prefix"] == "Modified:"
+    )
+    assert (
+        TestSignature.output_fields["output"].json_schema_extra["prefix"] == "Output:"
+    )
+
+
+def test_signature_parsing():
+    signature = Signature("input1, input2 -> output")
+    assert "input1" in signature.input_fields
+    assert "input2" in signature.input_fields
+    assert "output" in signature.output_fields
+
+
+def test_with_signature():
+    signature1 = Signature("input1, input2 -> output")
+    signature2 = signature1.with_instructions("This is a test")
+    assert signature2.instructions == "This is a test"
+    assert signature1 is not signature2, "The type should be immutable"
+
+
+def test_with_updated_field():
+    signature1 = Signature("input1, input2 -> output")
+    signature2 = signature1.with_updated_fields("input1", prefix="Modified:")
+    assert signature2.input_fields["input1"].json_schema_extra["prefix"] == "Modified:"
+    assert signature1.input_fields["input1"].json_schema_extra["prefix"] == "Input 1:"
+    assert signature1 is not signature2, "The type should be immutable"
+    for key in signature1.fields.keys():
+        if key != "input1":
+            assert (
+                signature1.fields[key].json_schema_extra
+                == signature2.fields[key].json_schema_extra
+            )
+    assert signature1.instructions == signature2.instructions
+
+
+def test_empty_signature():
+    with pytest.raises(ValueError):
+        Signature("")
+
+
+def test_instructions_signature():
+    sig = Signature("input -> output", "You are a helpful assistant.")
+    assert sig.instructions == "You are a helpful assistant."
+
+
+def test_signature_instructions():
+    sig1 = Signature("input1 -> output1", instructions="This is a test")
+    assert sig1.instructions == "This is a test"
+
+
+def test_signature_instructions_none():
+    sig1 = Signature("a, b -> c")
+    assert sig1.instructions == "Given the fields `a`, `b`, produce the fields `c`."
+
+
+def test_signature_from_dict():
+    signature = Signature(
+        {"input1": InputField(), "input2": InputField(), "output": OutputField()}
+    )
+    for k in ["input1", "input2", "output"]:
+        assert k in signature.fields
+        assert signature.fields[k].annotation == str
+
+
+def test_signature_from_dict_input_output():
+    signature = Signature(
+        {"input1": InputField(), "input2": InputField(), "output": OutputField()}
+    )
+    assert "input1" in signature.input_fields
+    assert "input2" in signature.input_fields
+    assert "output" in signature.output_fields
+
+
+def test_signature_equality():
+    sig1 = Signature("input1 -> output1")
+    sig2 = Signature("input1 -> output1")
+    assert sig1.equals(sig2)
+
+
+def test_signature_inequality():
+    sig1 = Signature("input1 -> output1")
+    sig2 = Signature("input2 -> output2")
+    assert not sig1.equals(sig2)
+
+
+def test_equality_format():
+    class TestSignature(Signature):
+        input = InputField(format=lambda x: x)
+        output = OutputField()
+
+    assert TestSignature.equals(TestSignature)
+
+
+def test_signature_reverse():
+    sig = Signature("input1 -> output1")
+    assert sig.signature == "input1 -> output1"
+
+
+def test_insert_field_at_various_positions():
+    class InitialSignature(Signature):
+        input1: str = InputField()
+        output1: int = OutputField()
+
+    S1 = InitialSignature.prepend("new_input_start", InputField(), str)
+    S2 = InitialSignature.append("new_input_end", InputField(), str)
+    assert "new_input_start" == list(S1.input_fields.keys())[0]
+    assert "new_input_end" == list(S2.input_fields.keys())[-1]
+
+    S3 = InitialSignature.prepend("new_output_start", OutputField(), str)
+    S4 = InitialSignature.append("new_output_end", OutputField(), str)
+    assert "new_output_start" == list(S3.output_fields.keys())[0]
+    assert "new_output_end" == list(S4.output_fields.keys())[-1]
+
+
+def test_infer_prefix():
+    assert infer_prefix("someAttributeName42IsCool") == "Some Attribute Name 42 Is Cool"
+    assert infer_prefix("version2Update") == "Version 2 Update"
+    assert infer_prefix("modelT45Enhanced") == "Model T 45 Enhanced"
+    assert infer_prefix("someAttributeName") == "Some Attribute Name"
+    assert infer_prefix("some_attribute_name") == "Some Attribute Name"
+    assert infer_prefix("URLAddress") == "URL Address"
+    assert infer_prefix("isHTTPSecure") == "Is HTTP Secure"
+    assert infer_prefix("isHTTPSSecure123") == "Is HTTPS Secure 123"
diff --git a/tests/teleprompt/test_bootstrap.py b/tests/teleprompt/test_bootstrap.py
new file mode 100644
index 0000000000..4758a5aae4
--- /dev/null
+++ b/tests/teleprompt/test_bootstrap.py
@@ -0,0 +1,180 @@
+import pytest
+import dspy
+from dspy.predict import Predict
+from dspy.utils.dummies import DummyLM
+from dspy import Example
+from dspy.teleprompt import BootstrapFewShot
+import textwrap
+
+
+# Define a simple metric function for testing
+def simple_metric(example, prediction, trace=None):
+    # Simplified metric for testing: true if prediction matches expected output
+    return example.output == prediction.output
+
+
+examples = [
+    Example(input="What is the color of the sky?", output="blue").with_inputs("input"),
+    Example(
+        input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!"
+ ), +] +trainset = [examples[0]] +valset = [examples[1]] + + +def test_bootstrap_initialization(): + # Initialize BootstrapFewShot with a dummy metric and minimal setup + bootstrap = BootstrapFewShot( + metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 + ) + assert bootstrap.metric == simple_metric, "Metric not correctly initialized" + + +class SimpleModule(dspy.Module): + def __init__(self, signature): + super().__init__() + self.predictor = Predict(signature) + + def forward(self, **kwargs): + return self.predictor(**kwargs) + + +def test_compile_with_predict_instances(): + # Create Predict instances for student and teacher + # Note that dspy.Predict is not itself a module, so we can't use it directly here + student = SimpleModule("input -> output") + teacher = SimpleModule("input -> output") + + lm = DummyLM(["Initial thoughts", "Finish[blue]"]) + dspy.settings.configure(lm=lm) + + # Initialize BootstrapFewShot and compile the student + bootstrap = BootstrapFewShot( + metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 + ) + compiled_student = bootstrap.compile( + student, teacher=teacher, trainset=trainset, valset=valset + ) + + assert compiled_student is not None, "Failed to compile student" + assert ( + hasattr(compiled_student, "_compiled") and compiled_student._compiled + ), "Student compilation flag not set" + + +def test_bootstrap_effectiveness(): + # This test verifies if the bootstrapping process improves the student's predictions + student = SimpleModule("input -> output") + teacher = SimpleModule("input -> output") + lm = DummyLM(["blue", "Ring-ding-ding-ding-dingeringeding!"], follow_examples=True) + dspy.settings.configure(lm=lm, trace=[]) + + bootstrap = BootstrapFewShot( + metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 + ) + compiled_student = bootstrap.compile( + student, teacher=teacher, trainset=trainset, valset=valset + ) + + # Check that the compiled student has the correct demos + assert len(compiled_student.predictor.demos) == 1 + assert compiled_student.predictor.demos[0].input == trainset[0].input + assert compiled_student.predictor.demos[0].output == trainset[0].output + + # Test the compiled student's prediction. + # We are using a DummyLM with follow_examples=True, which means that + # even though it would normally reply with "Ring-ding-ding-ding-dingeringeding!" + # on the second output, if it seems an example that perfectly matches the + # prompt, it will use that instead. That is why we expect "blue" here. + prediction = compiled_student(input=trainset[0].input) + assert prediction.output == trainset[0].output + + # For debugging + print("Convo") + print(lm.get_convo(-1)) + + assert lm.get_convo(-1) == textwrap.dedent( + """\ + Given the fields `input`, produce the fields `output`. + + --- + + Follow the following format. + + Input: ${input} + Output: ${output} + + --- + + Input: What is the color of the sky? + Output: blue + + --- + + Input: What is the color of the sky? 
+ Output: blue""" + ) + + +def test_error_handling_during_bootstrap(): + """ + Test to verify error handling during the bootstrapping process + """ + + class BuggyModule(dspy.Module): + def __init__(self, signature): + super().__init__() + self.predictor = Predict(signature) + + def forward(self, **kwargs): + raise RuntimeError("Simulated error") + + student = SimpleModule("input -> output") + teacher = BuggyModule("input -> output") + + # Setup DummyLM to simulate an error scenario + lm = DummyLM( + [ + "Initial thoughts", # Simulate initial teacher's prediction + ] + ) + dspy.settings.configure(lm=lm) + + bootstrap = BootstrapFewShot( + metric=simple_metric, + max_bootstrapped_demos=1, + max_labeled_demos=1, + max_errors=1, + ) + + with pytest.raises(RuntimeError, match="Simulated error"): + bootstrap.compile(student, teacher=teacher, trainset=trainset, valset=valset) + + +def test_validation_set_usage(): + """ + Test to ensure the validation set is correctly used during bootstrapping + """ + student = SimpleModule("input -> output") + teacher = SimpleModule("input -> output") + + lm = DummyLM( + [ + "Initial thoughts", + "Finish[blue]", # Expected output for both training and validation + ] + ) + dspy.settings.configure(lm=lm) + + bootstrap = BootstrapFewShot( + metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1 + ) + compiled_student = bootstrap.compile( + student, teacher=teacher, trainset=trainset, valset=valset + ) + + # Check that validation examples are part of student's demos after compilation + assert len(compiled_student.predictor.demos) >= len( + valset + ), "Validation set not used in compiled student demos" diff --git a/tests/teleprompt/test_ensemble.py b/tests/teleprompt/test_ensemble.py new file mode 100644 index 0000000000..292176af4f --- /dev/null +++ b/tests/teleprompt/test_ensemble.py @@ -0,0 +1,60 @@ +import pytest +import dspy +from dspy.teleprompt.ensemble import Ensemble + + +class MockProgram(dspy.Module): + def __init__(self, output): + super().__init__() + self.output = output + + def forward(self, *args, **kwargs): + return self.output + + +# Simple reduction function to test with +def mock_reduce_fn(outputs): + return sum(outputs) / len(outputs) + + +def test_ensemble_without_reduction(): + """Test that Ensemble correctly combines outputs without applying a reduce_fn.""" + programs = [MockProgram(i) for i in range(5)] + ensemble = Ensemble() + ensembled_program = ensemble.compile(programs) + + outputs = ensembled_program() + assert len(outputs) == 5, "Ensemble did not combine the correct number of outputs" + + +def test_ensemble_with_reduction(): + """Test that Ensemble correctly applies a reduce_fn to combine outputs.""" + programs = [MockProgram(i) for i in range(5)] + ensemble = Ensemble(reduce_fn=mock_reduce_fn) + ensembled_program = ensemble.compile(programs) + + output = ensembled_program() + expected_output = sum(range(5)) / 5 + assert output == expected_output, "Ensemble did not correctly apply the reduce_fn" + + +def test_ensemble_with_size_limitation(): + """Test that specifying a size limits the number of programs used in the ensemble.""" + programs = [MockProgram(i) for i in range(10)] + ensemble_size = 3 + ensemble = Ensemble(size=ensemble_size) + ensembled_program = ensemble.compile(programs) + + outputs = ensembled_program() + assert ( + len(outputs) == ensemble_size + ), "Ensemble did not respect the specified size limitation" + + +def test_ensemble_deterministic_behavior(): + """Verify that the Ensemble class raises an 
assertion for deterministic behavior.""" + with pytest.raises( + AssertionError, + match="TODO: Implement example hashing for deterministic ensemble.", + ): + Ensemble(deterministic=True) diff --git a/tests/teleprompt/test_finetune.py b/tests/teleprompt/test_finetune.py new file mode 100644 index 0000000000..f87f5c14cb --- /dev/null +++ b/tests/teleprompt/test_finetune.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file diff --git a/tests/teleprompt/test_knn_fewshot.py b/tests/teleprompt/test_knn_fewshot.py new file mode 100644 index 0000000000..b267d3dce8 --- /dev/null +++ b/tests/teleprompt/test_knn_fewshot.py @@ -0,0 +1,72 @@ +import pytest +import dsp, dspy +from dspy.predict.knn import KNN +from dspy.teleprompt.knn_fewshot import KNNFewShot +from dspy.utils.dummies import DummyLM, DummyVectorizer + + +def mock_example(question: str, answer: str) -> dsp.Example: + """Creates a mock DSP example with specified question and answer.""" + return dspy.Example(question=question, answer=answer).with_inputs("question") + + +@pytest.fixture +def setup_knn_few_shot(): + """Sets up a KNNFewShot instance for testing.""" + trainset = [ + mock_example("What is the capital of France?", "Paris"), + mock_example("What is the largest ocean?", "Pacific"), + mock_example("What is 2+2?", "4"), + ] + dsp.SentenceTransformersVectorizer = DummyVectorizer + knn_few_shot = KNNFewShot(KNN, k=2, trainset=trainset) + return knn_few_shot + + +def test_knn_few_shot_initialization(setup_knn_few_shot): + """Tests the KNNFewShot initialization.""" + knn_few_shot = setup_knn_few_shot + assert knn_few_shot.KNN.k == 2, "Incorrect k value for KNN" + assert len(knn_few_shot.KNN.trainset) == 3, "Incorrect trainset size for KNN" + + +class SimpleModule(dspy.Module): + def __init__(self, signature): + super().__init__() + self.predictor = dspy.Predict(signature) + + def forward(self, *args, **kwargs): + return self.predictor(**kwargs) + + def reset_copy(self): + # Creates a new instance of SimpleModule with the same predictor + return SimpleModule(self.predictor.signature) + + +# TODO: Test not working yet +def _test_knn_few_shot_compile(setup_knn_few_shot): + """Tests the compile method of KNNFewShot with SimpleModule as student.""" + student = SimpleModule("input -> output") + teacher = SimpleModule("input -> output") # Assuming teacher uses the same module type + + # Setup DummyLM with a response for a query similar to one of the training examples + lm = DummyLM(["Madrid", "10"]) + dspy.settings.configure(lm=lm) # Responses for the capital of Spain and the result of 5+5) + + knn_few_shot = setup_knn_few_shot + trainset = knn_few_shot.KNN.trainset + compiled_student = knn_few_shot.compile(student, teacher=teacher, trainset=trainset, valset=None) + + assert len(compiled_student.predictor.demos) == 1 + assert compiled_student.predictor.demos[0].input == trainset[0].input + assert compiled_student.predictor.demos[0].output == trainset[0].output + + # Simulate a query that is similar to one of the training examples + output = compiled_student.forward(input = "What is the capital of Spain?").output + + print("CONVO") + print(lm.get_convo(-1)) + + # Validate that the output corresponds to one of the expected DummyLM responses + # This assumes the compiled_student's forward method will execute the predictor with the given query + assert output in ["Madrid", "10"], "The compiled student did not return the correct output based on the query" diff --git a/tests/teleprompt/test_signature_opt.py b/tests/teleprompt/test_signature_opt.py 
new file mode 100644 index 0000000000..d7f3475514 --- /dev/null +++ b/tests/teleprompt/test_signature_opt.py @@ -0,0 +1,121 @@ +import textwrap +import dspy +from dspy.teleprompt.signature_opt import SignatureOptimizer +from dspy.utils.dummies import DummyLM +from dspy import Example + +# Define a simple metric function for testing +def simple_metric(example, prediction): + # Simplified metric for testing: true if prediction matches expected output + return example.output == prediction.output + +# Example training and validation sets +trainset = [ + Example(input="Question: What is the color of the sky?", output="blue").with_inputs("input"), + Example(input="Question: What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!").with_inputs("input"), +] + +def test_signature_optimizer_initialization(): + optimizer = SignatureOptimizer(metric=simple_metric, breadth=2, depth=1, init_temperature=1.4) + assert optimizer.metric == simple_metric, "Metric not correctly initialized" + assert optimizer.breadth == 2, "Breadth not correctly initialized" + assert optimizer.depth == 1, "Depth not correctly initialized" + assert optimizer.init_temperature == 1.4, "Initial temperature not correctly initialized" + +class SimpleModule(dspy.Module): + def __init__(self, signature): + super().__init__() + # SignatureOptimizer doesn't work with dspy.Predict + self.predictor = dspy.ChainOfThought(signature) + + def forward(self, **kwargs): + return self.predictor(**kwargs) + +def test_signature_optimizer_optimization_process(): + optimizer = SignatureOptimizer(metric=simple_metric, breadth=2, depth=1, init_temperature=1.4) + dspy.settings.configure(lm=DummyLM(["Optimized instruction 1", "Optimized instruction 2"])) + + student = SimpleModule("input -> output") + + # Assuming the compile method of SignatureOptimizer requires a student module, a development set, and evaluation kwargs + optimized_student = optimizer.compile(student, devset=trainset, eval_kwargs={"num_threads": 1, "display_progress": False}) + + # Check that the optimized student has been modified from the original + # This check can be more specific based on how the optimization modifies the student + assert optimized_student is not student, "Optimization did not modify the student" + + # Further tests can be added to verify the specifics of the optimization process, + # such as checking the instructions of the optimized student's predictors. 
+ +def test_signature_optimizer_statistics_tracking(): + optimizer = SignatureOptimizer(metric=simple_metric, breadth=2, depth=1, init_temperature=1.4) + optimizer.track_stats = True # Enable statistics tracking + + dspy.settings.configure(lm=DummyLM(["Optimized instruction"])) + student = SimpleModule("input -> output") + optimized_student = optimizer.compile(student, devset=trainset, eval_kwargs={"num_threads": 1, "display_progress": False}) + + # Verify that statistics have been tracked and attached to the optimized student + assert hasattr(optimized_student, 'total_calls'), "Total calls statistic not tracked" + assert hasattr(optimized_student, 'results_best'), "Best results statistics not tracked" + +# Assuming the setup_signature_optimizer fixture and simple_metric function are defined as before + +def test_optimization_and_output_verification(): + lm = DummyLM([ + "Optimized Prompt", + "Optimized Prefix", + ]) + dspy.settings.configure(lm=lm) + optimizer = SignatureOptimizer(metric=simple_metric, breadth=2, depth=1, init_temperature=1.4) + + student = SimpleModule("input -> output") + + # Compile the student with the optimizer + optimized_student = optimizer.compile(student, devset=trainset, eval_kwargs={"num_threads": 1, "display_progress": False}) + + # Simulate calling the optimized student with a new input + test_input = "What is the capital of France?" + prediction = optimized_student(input=test_input) + + print(lm.get_convo(-1)) + + assert prediction.output == "No more responses" + + assert lm.get_convo(-1) == textwrap.dedent("""\ + Optimized Prompt + + --- + + Follow the following format. + + Input: ${input} + Reasoning: Let's think step by step in order to ${produce the output}. We ... + Optimized Prefix ${output} + + --- + + Input: What is the capital of France? 
+ Reasoning: Let's think step by step in order to No more responses + Optimized Prefix No more responses""") + +def test_statistics_tracking_during_optimization(): + dspy.settings.configure(lm=DummyLM(["Optimized instruction for stats tracking"])) + + optimizer = SignatureOptimizer(metric=simple_metric, breadth=2, depth=1, init_temperature=1.4) + optimizer.track_stats = True # Enable statistics tracking + + student = SimpleModule("input -> output") + optimized_student = optimizer.compile(student, devset=trainset, eval_kwargs={"num_threads": 1, "display_progress": False}) + + # Verify that statistics have been tracked + assert hasattr(optimized_student, 'total_calls'), "Optimizer did not track total metric calls" + assert optimized_student.total_calls > 0, "Optimizer reported no metric calls" + + # Check if the results_best and results_latest contain valid statistics + assert 'results_best' in optimized_student.__dict__, "Optimizer did not track the best results" + assert 'results_latest' in optimized_student.__dict__, "Optimizer did not track the latest results" + assert len(optimized_student.results_best) > 0, "Optimizer did not properly populate the best results statistics" + assert len(optimized_student.results_latest) > 0, "Optimizer did not properly populate the latest results statistics" + + # Additional detailed checks can be added here to verify the contents of the tracked statistics diff --git a/tests/teleprompt/test_signature_opt_bayesian.py b/tests/teleprompt/test_signature_opt_bayesian.py new file mode 100644 index 0000000000..0cf655784f --- /dev/null +++ b/tests/teleprompt/test_signature_opt_bayesian.py @@ -0,0 +1,261 @@ +import textwrap +import pytest +import re +import dspy +from dsp.modules import LM +from dspy.teleprompt.signature_opt_bayesian import BayesianSignatureOptimizer +from dspy.utils.dummies import DummyLM +from dspy import Example + + +# Define a simple metric function for testing +def simple_metric(example, prediction, trace=None): + # Simplified metric for testing: true if prediction matches expected output + return example.output == prediction.output + +# Some example data +capitals = { + "Germany": "Berlin", + "France": "Paris", + "Denmark": "Copenhagen", + "Sweden": "Stockholm", + "Norway": "Oslo", +} +# Not used for training data +extra_capitals = { + "Spain": "Madrid", + "Portugal": "Lisbon", + "Italy": "Rome", +} + +# Example training and validation sets +trainset = [ + Example(input="What is the color of the sky?", output="blue").with_inputs("input"), + Example( + input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!" + ).with_inputs("input"), +] + [Example(input=f"What is the capital of {country}?", output=capital).with_inputs("input") for country, capital in capitals.items()] + + +class ConditionalLM(LM): + def __init__(self): + super().__init__("conditional-lm") + + def basic_request(self, prompt, n=1, **kwargs): + # If we are in the "optimization" stage, we don't say much. + if prompt.endswith("Observations:"): + answer = " (*silence*)" + elif prompt.endswith("Proposed Instruction:"): + answer = " Input: " + elif prompt.endswith("Proposed Prefix For Output Field:"): + answer = " Output: " + elif prompt.endswith("Summary:"): + answer = " summarizing..." 
+        else:
+            pairs = re.findall(r"Input: (.*)\nOutput: (.*)", prompt)
+
+            print("PROMPT:", prompt)
+            print("PAIRS:", pairs)
+
+            last = re.search(r"Input: (.*)\nReasoning: (.*)$", prompt)
+            current_question = last.group(1)
+
+            if match := re.match(r"What is the capital of (.*?)\?", current_question):
+                country = match.group(1)
+                # If we had a previous example of a question about a capital, the model
+                # has learned the format, and will answer the question correctly.
+                if any("capital" in question for question, _ in pairs):
+                    answer = (capitals | extra_capitals)[country]
+                # Otherwise, it is confused and will answer with the country's name.
+                else:
+                    answer = country
+            # For other questions, the model will answer with the last word of the question.
+            else:
+                answer = current_question.split()[-1]
+
+            answer = "think deeply.\nOutput: " + answer
+
+        RED, GREEN, RESET = '\033[91m', '\033[92m', '\033[0m'
+        print("=== DummyLM ===")
+        print(prompt, end="")
+        print(f"{RED}{answer}{RESET}")
+        print("===")
+
+        dummy_response = {"choices": []}
+        for _ in range(n):
+            dummy_response["choices"].append(
+                {
+                    "text": answer,
+                    "finish_reason": "done",
+                }
+            )
+
+        # Simulate processing and storing the request and response.
+        history_entry = {
+            "prompt": prompt,
+            "response": dummy_response,
+            "kwargs": kwargs,
+            "raw_kwargs": kwargs,
+        }
+        self.history.append(history_entry)
+
+        return dummy_response
+
+    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
+        response = self.basic_request(prompt, **kwargs)
+        return [choice["text"] for choice in response["choices"]]
+
+    def get_convo(self, index):
+        """get the prompt + answer from the ith message"""
+        return self.history[index]['prompt'] \
+            + " " \
+            + self.history[index]['response']['choices'][0]['text']
+
+
+def test_bayesian_signature_optimizer_initialization():
+    optimizer = BayesianSignatureOptimizer(
+        metric=simple_metric, n=10, init_temperature=1.4, verbose=True, track_stats=True
+    )
+    assert optimizer.metric == simple_metric, "Metric not correctly initialized"
+    assert optimizer.n == 10, "Incorrect 'n' parameter initialization"
+    assert (
+        optimizer.init_temperature == 1.4
+    ), "Initial temperature not correctly initialized"
+    assert optimizer.verbose is True, "Verbose flag not correctly initialized"
+    assert optimizer.track_stats is True, "Track stats flag not correctly initialized"
+
+
+class SimpleModule(dspy.Module):
+    def __init__(self, signature):
+        super().__init__()
+        # SignatureOptimizer doesn't work with dspy.Predict
+        self.predictor = dspy.ChainOfThought(signature)
+
+    def forward(self, **kwargs):
+        return self.predictor(**kwargs)
+
+
+def test_signature_optimizer_optimization_process():
+    lm = ConditionalLM()
+    dspy.settings.configure(lm=lm)
+
+    student = SimpleModule(signature="input -> output")
+
+    optimizer = BayesianSignatureOptimizer(
+        metric=simple_metric,
+        n=10,
+        init_temperature=1.4,
+        verbose=False,
+        track_stats=False,
+    )
+
+    # Adjustments: Include required parameters for the compile method
+    optimized_student = optimizer.compile(
+        student=student,
+        devset=trainset,
+        optuna_trials_num=10,
+        max_bootstrapped_demos=3,
+        max_labeled_demos=5,
+        eval_kwargs={"num_threads": 1, "display_progress": False},
+    )
+
+    assert len(optimized_student.predictor.demos) == 5
+
+
+def test_signature_optimizer_bad_lm():
+    dspy.settings.configure(
+        lm=DummyLM([f"Optimized instruction {i}" for i in range(30)])
+    )
+    student = SimpleModule(signature="input -> output")
+    optimizer = BayesianSignatureOptimizer(
metric=simple_metric, + n=10, + init_temperature=1.4, + verbose=False, + track_stats=False, + ) + + # Krista: when the code tries to generate bootstrapped examples, the examples are generated using DummyLM, + # which only outputs "Optimized instruction i" this means that none of the bootstrapped examples are successful, + # and therefore the set of examples that we're using to generate new prompts is empty + with pytest.raises(ValueError): + _optimized_student = optimizer.compile( + student=student, + devset=trainset, + optuna_trials_num=10, + max_bootstrapped_demos=3, + max_labeled_demos=5, + eval_kwargs={"num_threads": 1, "display_progress": False}, + ) + + +def test_optimization_and_output_verification(): + # Make a language model that is always right, except on the last + # example in the train set. + lm = ConditionalLM() + dspy.settings.configure(lm=lm) + + optimizer = BayesianSignatureOptimizer( + metric=simple_metric, + n=10, + init_temperature=1.4, + verbose=False, + track_stats=True, + ) + + student = SimpleModule("input -> output") + + # Compile the student with the optimizer + optimized_student = optimizer.compile( + student=student, + devset=trainset, + optuna_trials_num=4, + max_bootstrapped_demos=2, + max_labeled_demos=3, + eval_kwargs={"num_threads": 1, "display_progress": False}, + ) + + # Simulate calling the optimized student with a new input + test_input = "What is the capital of Spain?" + prediction = optimized_student(input=test_input) + + print("CORRECT ANSWER") + print(lm.get_convo(-1)) + + assert prediction.output == "Madrid" + + assert lm.get_convo(-1) == textwrap.dedent( + """\ + Input: + + --- + + Follow the following format. + + Input: ${input} + Reasoning: Let's think step by step in order to ${produce the output}. We ... + Output: ${output} + + --- + + Input: What is the capital of Norway? + Reasoning: Let's think step by step in order to think deeply. + Output: Oslo + + --- + + Input: What is the capital of Sweden? + Reasoning: Let's think step by step in order to think deeply. + Output: Stockholm + + --- + + Input: What is the capital of France? + Output: Paris + + --- + + Input: What is the capital of Spain? + Reasoning: Let's think step by step in order to think deeply. + Output: Madrid""" + )