diff --git a/docs/docs/tutorials/tool_use/index.ipynb b/docs/docs/tutorials/tool_use/index.ipynb
new file mode 100644
index 0000000000..56f7dda22f
--- /dev/null
+++ b/docs/docs/tutorials/tool_use/index.ipynb
@@ -0,0 +1,337 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Tutorial: Advanced Tool Use\n",
+    "\n",
+    "Let's walk through a quick example of building and prompt-optimizing a DSPy agent for advanced tool use. We'll do this for the challenging task [ToolHop](https://arxiv.org/abs/2501.02506), but with even stricter evaluation criteria.\n",
+    "\n",
+    "Install the latest DSPy via `pip install -U dspy` and follow along. You will also need to `pip install func_timeout`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<details>\n",
+    "<summary>Recommended: Set up MLflow Tracing to understand what's happening under the hood.</summary>\n",
+    "\n",
+    "### MLflow DSPy Integration\n",
+    "\n",
+    "MLflow is an LLMOps tool that natively integrates with DSPy and offers explainability and experiment tracking. In this tutorial, you can use MLflow to visualize prompts and optimization progress as traces, which helps you understand DSPy's behavior better. You can set up MLflow easily by following the four steps below.\n",
+    "\n",
+    "1. Install MLflow\n",
+    "\n",
+    "```bash\n",
+    "pip install \"mlflow>=2.20\"\n",
+    "```\n",
+    "\n",
+    "2. Start the MLflow UI in a separate terminal\n",
+    "```bash\n",
+    "mlflow ui --port 5000\n",
+    "```\n",
+    "\n",
+    "3. Connect the notebook to MLflow\n",
+    "```python\n",
+    "import mlflow\n",
+    "\n",
+    "mlflow.set_tracking_uri(\"http://localhost:5000\")\n",
+    "mlflow.set_experiment(\"DSPy\")\n",
+    "```\n",
+    "\n",
+    "4. Enable tracing.\n",
+    "```python\n",
+    "mlflow.dspy.autolog()\n",
+    "```\n",
+    "\n",
+    "To learn more about the integration, see the [MLflow DSPy Documentation](https://mlflow.org/docs/latest/llms/dspy/index.html).\n",
+    "</details>"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this tutorial, we'll demonstrate the new experimental `dspy.SIMBA` prompt optimizer, which tends to be powerful for larger LLMs and harder tasks. Using this, we'll improve our agent from 35% accuracy to 60%."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dspy\n",
+    "import ujson\n",
+    "import random\n",
+    "\n",
+    "gpt4o = dspy.LM(\"openai/gpt-4o\", temperature=0.7)\n",
+    "dspy.configure(lm=gpt4o)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's now download the data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Downloading 'ToolHop.json'...\n"
+     ]
+    }
+   ],
+   "source": [
+    "from dspy.utils import download\n",
+    "\n",
+    "download(\"https://huggingface.co/datasets/bytedance-research/ToolHop/resolve/main/data/ToolHop.json\")\n",
+    "\n",
+    "data = ujson.load(open(\"ToolHop.json\"))\n",
+    "random.Random(0).shuffle(data)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Then let's prepare a cleaned set of examples. The ToolHop task is interesting in that the agent gets a _unique set_ of tools (functions) to use separately for each request. Thus, it needs to learn how to use _any_ such tools effectively in practice."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "import inspect\n",
+    "\n",
+    "examples = []\n",
+    "fns2code = {}\n",
+    "\n",
+    "def finish(answer: str):\n",
+    "    \"\"\"Conclude the trajectory and return the final answer.\"\"\"\n",
+    "    return answer\n",
+    "\n",
+    "for datapoint in data:\n",
+    "    func_dict = {}\n",
+    "    for func_code in datapoint[\"functions\"]:\n",
+    "        cleaned_code = func_code.rsplit(\"\\n\\n# Example usage\", 1)[0]\n",
+    "        fn_name = re.search(r\"^\\s*def\\s+([a-zA-Z0-9_]+)\\s*\\(\", cleaned_code)\n",
+    "        fn_name = fn_name.group(1) if fn_name else None\n",
+    "\n",
+    "        if not fn_name:\n",
+    "            continue\n",
+    "\n",
+    "        local_vars = {}\n",
+    "        exec(cleaned_code, {}, local_vars)\n",
+    "        fn_obj = local_vars.get(fn_name)\n",
+    "\n",
+    "        if callable(fn_obj):\n",
+    "            func_dict[fn_name] = fn_obj\n",
+    "            assert fn_obj not in fns2code, f\"Duplicate function found: {fn_name}\"\n",
+    "            fns2code[fn_obj] = (fn_name, cleaned_code)\n",
+    "\n",
+    "    func_dict[\"finish\"] = finish\n",
+    "\n",
+    "    example = dspy.Example(question=datapoint[\"question\"], answer=datapoint[\"answer\"], functions=func_dict)\n",
+    "    examples.append(example.with_inputs(\"question\", \"functions\"))\n",
+    "\n",
+    "trainset, devset, testset = examples[:100], examples[100:400], examples[400:]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "And let's define some helpers for the task. Here, we will define the `metric`, which will be (much) stricter than in the original paper: we'll expect the prediction to match the ground truth exactly (after normalization). We'll also be strict in a second way: we'll only allow the agent to take 5 steps in total, in the interest of efficient deployment."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from func_timeout import func_set_timeout\n",
+    "\n",
+    "def wrap_function_with_timeout(fn):\n",
+    "    @func_set_timeout(10)\n",
+    "    def wrapper(*args, **kwargs):\n",
+    "        try:\n",
+    "            return {\"return_value\": fn(*args, **kwargs), \"errors\": None}\n",
+    "        except Exception as e:\n",
+    "            return {\"return_value\": None, \"errors\": str(e)}\n",
+    "\n",
+    "    return wrapper\n",
+    "\n",
+    "def fn_metadata(func):\n",
+    "    signature = inspect.signature(func)\n",
+    "    docstring = inspect.getdoc(func) or \"No docstring.\"\n",
+    "    return dict(function_name=func.__name__, arguments=str(signature), docstring=docstring)\n",
+    "\n",
+    "def metric(example, pred, trace=None):\n",
+    "    # Use removesuffix, not rstrip(\".0\"): rstrip strips any trailing '.' and '0' characters (e.g., \"2000\" -> \"2\").\n",
+    "    gold = str(example.answer).removesuffix(\".0\").replace(\",\", \"\").lower()\n",
+    "    pred = str(pred.answer).removesuffix(\".0\").replace(\",\", \"\").lower()\n",
+    "    return pred == gold  # stricter than the original paper's metric!\n",
+    "\n",
+    "evaluate = dspy.Evaluate(devset=devset, metric=metric, num_threads=24, display_progress=True, display_table=0, max_errors=999)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now, let's define the agent! The core of our agent will be based on a ReAct loop, in which the model sees the trajectory so far and the set of functions available to invoke, and decides the next tool to call.\n",
+    "\n",
+    "To keep the final agent fast, we'll limit its `max_steps` to 5 steps. We'll also run each function call with a timeout."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Agent(dspy.Module):\n",
+    "    def __init__(self, max_steps=5):\n",
+    "        self.max_steps = max_steps\n",
+    "        instructions = \"For the final answer, produce short (not full sentence) answers in which you format dates as YYYY-MM-DD, names as Firstname Lastname, and numbers without leading 0s.\"\n",
+    "        signature = dspy.Signature('question, trajectory, functions -> next_selected_fn, args: dict[str, Any]', instructions)\n",
+    "        self.react = dspy.ChainOfThought(signature)\n",
+    "\n",
+    "    def forward(self, question, functions):\n",
+    "        tools = {fn_name: fn_metadata(fn) for fn_name, fn in functions.items()}\n",
+    "        trajectory = []\n",
+    "\n",
+    "        for _ in range(self.max_steps):\n",
+    "            pred = self.react(question=question, trajectory=trajectory, functions=tools)\n",
+    "            selected_fn = pred.next_selected_fn.strip('\"').strip(\"'\")\n",
+    "            fn_output = wrap_function_with_timeout(functions[selected_fn])(**pred.args)\n",
+    "            trajectory.append(dict(reasoning=pred.reasoning, selected_fn=selected_fn, args=pred.args, **fn_output))\n",
+    "\n",
+    "            if selected_fn == \"finish\":\n",
+    "                break\n",
+    "\n",
+    "        return dspy.Prediction(answer=fn_output.get(\"return_value\", ''), trajectory=trajectory)"
+   ]
+  },
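+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before evaluating at scale, it can help to smoke-test the agent on a single example. (A minimal sketch: the printed answer and step count will vary from run to run.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "agent = Agent()\n",
+    "pred = agent(**devset[0].inputs())\n",
+    "print(pred.answer)\n",
+    "print(f\"{len(pred.trajectory)} tool calls\")"
+   ]
+  },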
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Out of the box, let's assess our `GPT-4o`-powered agent on the development set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2025/03/23 21:46:10 INFO dspy.evaluate.evaluate: Average Metric: 105.0 / 300 (35.0%)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "35.0"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "agent = Agent()\n",
+    "evaluate(agent)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now, let's optimize the agent using `dspy.SIMBA`, which stands for **Stochastic Introspective Mini-Batch Ascent**. This prompt optimizer accepts arbitrary DSPy programs like our agent here and proceeds in a sequence of mini-batches, seeking to make incremental improvements to the prompt instructions or few-shot examples."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "simba = dspy.SIMBA(metric=metric, max_steps=12, max_demos=10)\n",
+    "optimized_agent = simba.compile(agent, trainset=trainset, seed=6793115)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Having completed this optimization, let's now evaluate our agent again. We see a substantial 73% relative gain, jumping from 35% to 60.7% accuracy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2025/03/23 21:46:21 INFO dspy.evaluate.evaluate: Average Metric: 182.0 / 300 (60.7%)"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "60.67"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "evaluate(optimized_agent)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "jun2024_py310",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 59b741d6a3..0f4427659e 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -53,6 +53,7 @@ nav:
     - Classification Finetuning: tutorials/classification_finetuning/index.ipynb
     - Multi-Hop Search: tutorials/multihop_search/index.ipynb
     - Privacy-Conscious Delegation: tutorials/papillon/index.md
+    - Advanced Tool Use: tutorials/tool_use/index.ipynb
     - Image Generation Prompt iteration: tutorials/image_generation_prompting/index.ipynb
     - Output Refinement: tutorials/output_refinement/best-of-n-and-refine.md
     - Finetuning Agents: tutorials/games/index.ipynb
diff --git a/dspy/primitives/prediction.py b/dspy/primitives/prediction.py
index 2fe78c1c61..63b74873f9 100644
--- a/dspy/primitives/prediction.py
+++ b/dspy/primitives/prediction.py
@@ -29,6 +29,67 @@ def __repr__(self):
 
     def __str__(self):
         return self.__repr__()
+
+    def __float__(self):
+        if "score" not in self._store:
+            raise ValueError("Prediction object does not have a 'score' field to convert to float.")
+        return float(self._store["score"])
+
+    def __add__(self, other):
+        if isinstance(other, (float, int)):
+            return self.__float__() + other
+        elif isinstance(other, Prediction):
+            return self.__float__() + float(other)
+        raise TypeError(f"Unsupported type for addition: {type(other)}")
+
+    def __radd__(self, other):
+        if isinstance(other, (float,
int)): + return other + self.__float__() + elif isinstance(other, Prediction): + return float(other) + self.__float__() + raise TypeError(f"Unsupported type for addition: {type(other)}") + + def __truediv__(self, other): + if isinstance(other, (float, int)): + return self.__float__() / other + elif isinstance(other, Prediction): + return self.__float__() / float(other) + raise TypeError(f"Unsupported type for division: {type(other)}") + + def __rtruediv__(self, other): + if isinstance(other, (float, int)): + return other / self.__float__() + elif isinstance(other, Prediction): + return float(other) / self.__float__() + raise TypeError(f"Unsupported type for division: {type(other)}") + + def __lt__(self, other): + if isinstance(other, (float, int)): + return self.__float__() < other + elif isinstance(other, Prediction): + return self.__float__() < float(other) + raise TypeError(f"Unsupported type for comparison: {type(other)}") + + def __le__(self, other): + if isinstance(other, (float, int)): + return self.__float__() <= other + elif isinstance(other, Prediction): + return self.__float__() <= float(other) + raise TypeError(f"Unsupported type for comparison: {type(other)}") + + def __gt__(self, other): + if isinstance(other, (float, int)): + return self.__float__() > other + elif isinstance(other, Prediction): + return self.__float__() > float(other) + raise TypeError(f"Unsupported type for comparison: {type(other)}") + + def __ge__(self, other): + if isinstance(other, (float, int)): + return self.__float__() >= other + elif isinstance(other, Prediction): + return self.__float__() >= float(other) + raise TypeError(f"Unsupported type for comparison: {type(other)}") @property def completions(self): diff --git a/dspy/teleprompt/__init__.py b/dspy/teleprompt/__init__.py index 1fe7c31400..3168cd1c44 100644 --- a/dspy/teleprompt/__init__.py +++ b/dspy/teleprompt/__init__.py @@ -5,6 +5,7 @@ from dspy.teleprompt.copro_optimizer import COPRO from dspy.teleprompt.ensemble import Ensemble from dspy.teleprompt.knn_fewshot import KNNFewShot +from dspy.teleprompt.simba import SIMBA from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2 from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch @@ -27,4 +28,5 @@ "BootstrapFewShotWithOptuna", "LabeledFewShot", "InferRules", + "SIMBA", ] diff --git a/dspy/teleprompt/simba.py b/dspy/teleprompt/simba.py new file mode 100644 index 0000000000..9d41c563ab --- /dev/null +++ b/dspy/teleprompt/simba.py @@ -0,0 +1,313 @@ +import dspy +import random +import logging + +import numpy as np +from typing import Callable +from dspy.teleprompt.teleprompt import Teleprompter +from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule + +logger = logging.getLogger(__name__) + + +# Stochastic Introspective Mini-Batch Ascent +class SIMBA(Teleprompter): + def __init__( + self, + *, + metric: Callable, + bsize=32, + num_candidates=6, + max_steps=8, + max_demos=4, + demo_input_field_maxlen=100_000, + num_threads=16, + temperature_for_sampling=0.2, + temperature_for_candidates=0.2, + ): + """ + :param metric: A function (Example, prediction_dict) -> float + :param bsize: mini-batch size + :param num_candidates: how many new candidate programs to produce per iteration + :param max_steps: how many optimization steps to run + :param max_demos: how many demos we allow a predictor to hold before we must drop some + :param demo_input_field_maxlen: how many characters of an input field to keep when building a new demo 
+ :param num_threads: how many threads for run_parallel + :param temperature_for_sampling: temperature used for picking programs for the trajectory-sampling step + :param temperature_for_candidates: temperature used for picking the source program for building new candidates + """ + self.metric = metric + self.bsize = bsize + self.num_candidates = num_candidates + self.max_steps = max_steps + self.max_demos = max_demos + self.demo_input_field_maxlen = demo_input_field_maxlen + self.num_threads = num_threads + + self.temperature_for_sampling = temperature_for_sampling + self.temperature_for_candidates = temperature_for_candidates + + self.strategies = [append_a_demo(demo_input_field_maxlen), append_a_rule] + + def compile(self, student: dspy.Module, *, trainset: list[dspy.Example], seed: int = 0): + # Basic checks + assert len(trainset) >= self.bsize, f"Trainset too small: {len(trainset)} < {self.bsize}" + + # Initialize RNG + rng = random.Random(seed) + rng_np = np.random.default_rng(seed) + + programs = [] + program_scores = {} + next_program_idx = 0 + + # Helper functions + def calc_average_score(prog_idx: int) -> float: + scores = program_scores.get(prog_idx, []) + if not scores: + return 0.0 + return sum(scores) / len(scores) + + def top_k_plus_baseline(k: int) -> list[int]: + # Sort all programs by descending average score + scored_programs = sorted(programs, key=lambda p: calc_average_score(p.simba_idx), reverse=True) + top_k = [p.simba_idx for p in scored_programs[:k]] + # Ensure baseline=0 is in there: + if 0 not in top_k and len(top_k) > 0: + top_k[-1] = 0 + return list(dict.fromkeys(top_k)) + + def softmax_sample(rng_obj: random.Random, program_idxs: list[int], temperature: float) -> int: + if not program_idxs: + raise ValueError("No programs available for softmax sampling.") + + # Unnormalized weights + scores = [calc_average_score(idx) for idx in program_idxs] + exps = [np.exp(s / temperature) for s in scores] + sum_exps = sum(exps) + if sum_exps <= 0: + # Fallback: uniform if all exps are zero + return rng_obj.choice(program_idxs) + + # Weighted random choice + probs = [val / sum_exps for val in exps] + return rng_obj.choices(program_idxs, weights=probs, k=1)[0] + + def register_new_program(prog: dspy.Module, score_list: list[float]): + nonlocal next_program_idx + next_program_idx += 1 + new_idx = next_program_idx + prog.simba_idx = new_idx + programs.append(prog) + program_scores[new_idx] = score_list + + # Initialize the baseline program: index=0 + student = student.deepcopy() + student.simba_idx = 0 + programs.append(student) + program_scores[0] = [] + + winning_programs = [student] + + # Data shuffling + data_indices = list(range(len(trainset))) + rng.shuffle(data_indices) + instance_idx = 0 + + # Parallel runner + run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads) + + for batch_idx in range(self.max_steps): + logger.info(f"Starting batch {batch_idx+1} of {self.max_steps}.") + + # STEP 1: Get next batch + if instance_idx + self.bsize > len(trainset): + rng.shuffle(data_indices) + instance_idx = 0 + + batch_indices = data_indices[instance_idx : instance_idx + self.bsize] + batch = [trainset[i] for i in batch_indices] + instance_idx += self.bsize + + # We'll generate (program, model) pairs for the trajectory sampling. + # Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0]. 
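+            # Note: each of these LMs is paired with every example in the mini-batch below,
+            # so one optimization step samples bsize * num_candidates trajectories in total
+            # (e.g., 32 x 6 = 192 with the default settings).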
+ models = prepare_models_for_resampling(programs[0], self.num_candidates) + top_programs = top_k_plus_baseline(self.num_candidates) + + exec_pairs = [] + predictor2name = {} + + # For each model, for each example, pick a program from the pool via softmax + for model in models: + for example in batch: + chosen_prog_idx = softmax_sample(rng, top_programs, self.temperature_for_sampling) + candidate_system = programs[chosen_prog_idx].deepcopy() + candidate_system.set_lm(model) + + for name, predictor in candidate_system.named_predictors(): + predictor2name[id(predictor)] = name + + # Use the special wrap that includes the 'example' in the output + wrapped_candidate_system = wrap_program(candidate_system, self.metric) + exec_pairs.append((wrapped_candidate_system, example)) + + # STEP 2: Execute + logger.info(f"Sampling program trajectories on {self.bsize} examples x {self.num_candidates} samples.") + outputs = run_parallel(exec_pairs) + assert len(outputs) == len(exec_pairs) == self.bsize * self.num_candidates + + # STEP 3: Sort the training buckets by (max-to-min gap, max score, and max-to-avg gap). + buckets = [] + largest_max_to_avg_gap = float("-inf") + batch_10th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 10) + batch_90th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 90) + + # We'll chunk `outputs` by example index, each chunk has length = num_candidates + for idx, example in enumerate(batch): + # gather all results for this example + bucket = [outputs[i] for i in range(idx, len(outputs), self.bsize)] + bucket.sort(key=lambda x: x["score"], reverse=True) + + max_score = float(bucket[0]["score"]) + min_score = float(bucket[-1]["score"]) + avg_score = sum(x["score"] for x in bucket) / len(bucket) + max_to_min_gap = max_score - min_score + max_to_avg_gap = max_score - avg_score + if max_to_avg_gap > largest_max_to_avg_gap: + largest_max_to_avg_gap = max_to_avg_gap + + buckets.append((bucket, (max_to_min_gap, max_score, max_to_avg_gap))) + + # sort the buckets + buckets.sort(key=lambda x: x[1], reverse=True) + + # Baseline for the batch is just the average of all runs + all_scores_in_this_batch = [o["score"] for o in outputs] + baseline_score = sum(all_scores_in_this_batch) / len(all_scores_in_this_batch) + logger.info(f"Batch {batch_idx+1}: Baseline mini-batch score: {baseline_score}\n") + + # STEP 4: Build new candidate programs by applying a strategy to some top buckets. + system_candidates = [] + for bucket_idx, (bucket, bucket_stats) in enumerate(buckets): + max_to_min_gap, max_score, max_to_avg_gap = bucket_stats + logger.info( + f"Batch {batch_idx+1}: Processing bucket #{bucket_idx+1}, with max score {max_score}, " + f"max-to-min gap {max_to_min_gap}, and max-to-avg gap {max_to_avg_gap}." 
+ ) + + # pick source program + src_prog_idx = softmax_sample( + rng, top_k_plus_baseline(self.num_candidates), self.temperature_for_candidates + ) + system_candidate = programs[src_prog_idx].deepcopy() + + # Drop some demos from each predictor + name2predictor = {} + num_demos_list = [] + + for name, predictor in system_candidate.named_predictors(): + name2predictor[name] = predictor + num_demos_list.append(len(predictor.demos)) + + num_demos = max(num_demos_list) if num_demos_list else 0 + num_demos_to_drop = max(rng_np.poisson(num_demos / self.max_demos), int(num_demos >= self.max_demos)) + num_demos_to_drop = min(num_demos_to_drop, num_demos) + demos_to_drop = [rng.randrange(num_demos) for _ in range(num_demos_to_drop)] + + for name, predictor in name2predictor.items(): + predictor.demos = [demo for idxd, demo in enumerate(predictor.demos) if idxd not in demos_to_drop] + + # Pick a strategy + strategy = rng.choice(self.strategies) + logger.info( + f"Batch {batch_idx+1}: Invoking strategy: {strategy.__name__}" + + (f", having dropped {num_demos_to_drop} demos per predictor" if num_demos_to_drop else "") + ) + + try: + strategy( + bucket, + system_candidate, + predictor2name=predictor2name, + name2predictor=name2predictor, + batch_10p_score=batch_10th_percentile_score, + batch_90p_score=batch_90th_percentile_score, + ) + except Exception as e: + logger.error(f"Strategy failed with error: {e}") + continue + + system_candidates.append(system_candidate) + logger.info("\n") + + if len(system_candidates) >= self.num_candidates + 1: + break + + # STEP 5: Evaluate these new system_candidates on the same mini-batch + logger.info(f"Batch {batch_idx+1}: Evaluating {len(system_candidates)} programs on {self.bsize} examples.") + + exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in system_candidates for ex in batch] + outputs = run_parallel(exec_pairs) + assert len(outputs) == len(exec_pairs) == len(system_candidates) * self.bsize + + # STEP 6: Compute average mini-batch scores for each new candidate + candidate_scores = [] + for idx_cand, cand_sys in enumerate(system_candidates): + start = idx_cand * self.bsize + end = (idx_cand + 1) * self.bsize + sys_scores = [outputs[i]["score"] for i in range(start, end)] + avg_sys_score = sum(sys_scores) / len(sys_scores) + candidate_scores.append(avg_sys_score) + + logger.info( + f"Scores after {batch_idx+1} batches: {candidate_scores}, " + f"Best: {max(candidate_scores) if candidate_scores else 'N/A'}\n" + ) + + # STEP 7: Select the best among these new ones for "winning" record + if candidate_scores: + best_idx_among_candidates = candidate_scores.index(max(candidate_scores)) + best_program = system_candidates[best_idx_among_candidates] + winning_programs.append(best_program) + + # STEP 8: Register all new candidate systems in our global pool + for idx_cand, cand_sys in enumerate(system_candidates): + start = idx_cand * self.bsize + end = (idx_cand + 1) * self.bsize + sys_scores = [outputs[i]["score"] for i in range(start, end)] + register_new_program(cand_sys, sys_scores) + + M = len(winning_programs) - 1 + N = self.num_candidates + 1 + if M < 1: + # Only one or zero winning programs + program_idxs = [0] * N + else: + program_idxs = [round(i * M / (N - 1)) for i in range(N)] + program_idxs = list(dict.fromkeys(program_idxs)) + + candidate_programs = [winning_programs[i] for i in program_idxs] + logger.info(f"VALIDATION: Evaluating {len(candidate_programs)} programs on the full trainset.") + exec_pairs = [(wrap_program(sys, self.metric), ex) for 
sys in candidate_programs for ex in trainset] + outputs = run_parallel(exec_pairs) + + scores = [] + for idx_prog, prog in enumerate(candidate_programs): + start = idx_prog * len(trainset) + end = (idx_prog + 1) * len(trainset) + sys_scores = [outputs[i]["score"] for i in range(start, end)] + avg_score = sum(sys_scores) / len(sys_scores) if sys_scores else 0.0 + scores.append(avg_score) + + best_idx = scores.index(max(scores)) if scores else 0 + best_program = candidate_programs[best_idx] + logger.info( + f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} " + f"(at index {best_idx if scores else 'N/A'})\n\n\n" + ) + + # FIXME: Attach all program candidates in decreasing average score to the best program. + best_program.candidate_programs = candidate_programs + best_program.winning_programs = winning_programs + + return best_program diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py new file mode 100644 index 0000000000..0aea7b5d33 --- /dev/null +++ b/dspy/teleprompt/simba_utils.py @@ -0,0 +1,212 @@ +import dspy +import ujson +import inspect +import logging +import textwrap + +from dspy.adapters.chat_adapter import enumerate_fields +from dspy.signatures import InputField, OutputField +from typing import Callable + +logger = logging.getLogger(__name__) + + +def prepare_models_for_resampling(program: dspy.Module, n: int): + lm = program.get_lm() or dspy.settings.lm + temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / n) for i in range(n)] + temps = list(dict.fromkeys(temps))[:n] + return [lm.copy(temperature=t) for t in temps] + + +def wrap_program(program: dspy.Module, metric: Callable): + def wrapped_program(example): + with dspy.context(trace=[]): + prediction, trace, score = None, None, 0.0 + try: + prediction = program(**example.inputs()) + except Exception as e: + print(e) + trace = dspy.settings.trace.copy() + + try: + score = metric(example, prediction) + except Exception as e: + print(e) + + # Include the `example` in the output for subsequent usage in buckets/strategies. + return { + "prediction": prediction, + "trace": trace, + "score": score, + "example": example + } + + return wrapped_program + + + +def append_a_demo(demo_input_field_maxlen): + def append_a_demo_(bucket, system, **kwargs): + predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"] + + trace = bucket[0]["trace"] + name2demo = {} + + for step in trace: + predictor, _inputs, _outputs = step + + for k, v in _inputs.items(): + if demo_input_field_maxlen and len(str(v)) > demo_input_field_maxlen: + _inputs[k] = f"{str(v)[:demo_input_field_maxlen]}\n\t\t... 
" + + demo = dspy.Example(augmented=True, **_inputs, **_outputs) + name = predictor2name[id(predictor)] + name2demo[name] = demo # keep the last demo for each predictor + + for name, demo in name2demo.items(): + predictor = name2predictor[name] + predictor.demos.append(demo) + + logger.info(f"Added {len(name2demo)} demos (one each) across all predictors.") + return True + + return append_a_demo_ + + +def append_a_rule(bucket, system, **kwargs): + predictor2name = kwargs["predictor2name"] + batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] + + module_names = [name for name, _ in system.named_predictors()] + good, bad = bucket[0], bucket[-1] + example = good["example"] + + if good["score"] < batch_10p_score or bad["score"] > batch_90p_score: + logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile " + f"*or* bad score {bad['score']} is above the 90th percentile.") + return False + + if good["score"] <= bad["score"]: + if good["score"] > batch_90p_score: + bad["trace"] = [] + bad["score"] = "N/A" + bad["prediction"] = {"N/A": "Prediction not available"} + else: + good["trace"] = [] + good["score"] = "N/A" + good["prediction"] = {"N/A": "Prediction not available"} + + better_trajectory = [ + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) + for p, i, o in good["trace"] + ] + worse_trajectory = [ + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) + for p, i, o in bad["trace"] + ] + + kwargs = dict( + program_code=inspect.getsource(system.__class__), + modules_defn=inspect_modules(system), + program_inputs={**example.inputs()}, + oracle_metadata={**example.labels()}, + better_program_trajectory=better_trajectory, + better_program_outputs=dict(good["prediction"]), + worse_program_trajectory=worse_trajectory, + worse_program_outputs=dict(bad["prediction"] or {}), + worse_reward_value=bad["score"], + better_reward_value=good["score"], + module_names=module_names, + ) + + kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2) + for k, v in kwargs.items()} + advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice + + for name, predictor in system.named_predictors(): + if name in advice: + logger.info(f"Advice for {name}: {advice[name]}") + instructions = predictor.signature.instructions + "\n\n" + advice[name] + predictor.signature = predictor.signature.with_instructions(instructions) + + return True + +class OfferFeedback(dspy.Signature): + """ + You will be given two trajectories of an LLM-driven program's execution. Your goal is to help the program's modules + build up experience on how to maximize the reward value assigned to the program's outputs if it were to receive + similar inputs in the future. + + The module won't see its own history. It will rely on your advice balancing being concrete and being generalizable. + + In your advice: + - Avoid boilerplate. Offer advice that would change the module's behavior for the better in the future. + - Ensure that advice offered to a module M is specific to that M's specific sub-task, not the overall program. + - Rely on contrasting the behavior of the worse trajectory against the better trajectory in making recommendations. + - Ensure each unique module name appears exactly once as a key in the advice dictionary. 
+ """ + + program_code: str = InputField(desc="The code of the program that we are analyzing") + modules_defn: str = InputField(desc="The definition of each module in the program, including its I/O") + program_inputs: str = InputField(desc="The inputs to the program that we are analyzing") + oracle_metadata: str = InputField(desc="Any (hidden) metadata about the training set instance we're analyzing") + worse_program_trajectory: str = InputField( + desc="The trajectory of the program's execution, showing each module's I/O" + ) + worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") + worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + better_program_trajectory: str = InputField( + desc="The trajectory of the program's execution, showing each module's I/O" + ) + better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") + better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice") + discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did") + module_advice: dict[str, str] = OutputField( + desc="For each module, describe very concretely: If the module receives ${description of input or patterns " + "therein}, then it should ${description of content, behavior, or strategies to adopt and/or others to avoid}. " + "Basically, your advice be such that if the module has access to your tip, it would be much more likely to act " + "like the successful trajectory rather than the lower-scoring trajectory." + ) + + +def inspect_modules(program): + separator = "-" * 80 + output = [separator] + + for idx, (name, predictor) in enumerate(program.named_predictors()): + signature = predictor.signature + instructions = textwrap.dedent(signature.instructions) + instructions = ("\n" + "\t" * 2).join([""] + instructions.splitlines()) + + output.append(f"Module {name}") + output.append("\n\tInput Fields:") + output.append(("\n" + "\t" * 2).join([""] + enumerate_fields(signature.input_fields).splitlines())) + output.append("\tOutput Fields:") + output.append(("\n" + "\t" * 2).join([""] + enumerate_fields(signature.output_fields).splitlines())) + output.append(f"\tOriginal Instructions: {instructions}") + output.append(separator) + + return "\n".join([o.strip("\n") for o in output]) + + +def recursive_mask(o): + # If the object is already serializable, return it. + try: + ujson.dumps(o) + return o + except TypeError: + pass + + # If it's a dictionary, apply recursively to its values. + if isinstance(o, dict): + return {k: recursive_mask(v) for k, v in o.items()} + # If it's a list, apply recursively. + elif isinstance(o, list): + return [recursive_mask(v) for v in o] + # If it's a tuple, apply recursively. + elif isinstance(o, tuple): + return tuple(recursive_mask(v) for v in o) + # Otherwise, replace it with a placeholder string (or use repr(o)). + else: + return f""