36 changes: 25 additions & 11 deletions docs/docs/building-blocks/6-optimizers.md
@@ -27,40 +27,54 @@ DSPy programs consist of multiple calls to LMs, stacked together as [DSPy module

Given a metric, DSPy can optimize all of these three with multi-stage optimization algorithms. These can combine gradient descent (for LM weights) and discrete LM-driven optimization, i.e. for crafting/updating instructions and for creating/validating demonstrations. DSPy Demonstrations are like few-shot examples, but they're far more powerful. They can be created from scratch, given your program, and their creation and selection can be optimized in many effective ways.

In many cases, we found that compiling leads to better prompts than humans write. Not because DSPy optimizers are more creative than humans, but simply because they can try more things, much more systematically, and tune the metrics directly.
In many cases, we found that compiling leads to better prompts than human writing. Not because DSPy optimizers are more creative than humans, but simply because they can try more things, much more systematically, and tune the metrics directly.


## What DSPy Optimizers are currently available?

<!-- The following diagram was generated by: -->
Collaborator comment: Please give yourself credit! :)

<!-- 1. Running symilar on the teleprompter module to extract the python hierarchy as a Graphviz dot file -->
<!-- 2. Hand-editing the resulting dot file to remove classes that are not teleprompters/optimizers (e.g., classes for data structures manipulated by optimizers). -->
<!-- 3. Using dot to compile the `.dot` file into a PNG -->
<!-- Robert Goldman [2024/05/11:rpg] -->

[Subclasses of Teleprompter](figures/teleprompter-classes.png)

All of these can be accessed via `from dspy.teleprompt import *`.

#### Automatic Few-Shot Learning

1. **`LabeledFewShot`**: Simply constructs few-shot examples from provided labeled Q/A pairs.
These optimizers extend the signature by automatically generating and including **optimized** examples within the prompt sent to the model, implementing few-shot learning; a brief usage sketch follows this list.

1. **`LabeledFewShot`**: Simply constructs few-shot examples (demos) from provided labeled input and output data points. Requires `k` (number of examples for the prompt) and `trainset` to randomly select `k` examples from.

2. **`BootstrapFewShot`**: Uses a `teacher` module (which defaults to your program) to generate complete demonstrations for every stage of your program, along with labeled examples in `trainset`. Parameters include `max_labeled_demos` (the number of demonstrations randomly selected from the `trainset`) and `max_bootstrapped_demos` (the number of additional examples generated by the `teacher`). The bootstrapping process employs the metric to validate demonstrations, including only those that pass the metric in the "compiled" prompt. Advanced: Supports using a `teacher` program that is a *different* DSPy program that has compatible structure, for harder tasks.

2. **`BootstrapFewShot`**: Uses your program to self-generate complete demonstrations for every stage of your program. Will simply use the generated demonstrations (if they pass the metric) without any further optimization. Advanced: Supports using a teacher program (a different DSPy program that has compatible structure) and a teacher LM, for harder tasks.
3. **`BootstrapFewShotWithRandomSearch`**: Applies `BootstrapFewShot` several times with random search over generated demonstrations, and selects the best program found during optimization. Parameters mirror those of `BootstrapFewShot`, with the addition of `num_candidate_programs`, which specifies the number of random programs evaluated during optimization; the candidates include the uncompiled program, the `LabeledFewShot`-optimized program, a `BootstrapFewShot`-compiled program with unshuffled examples, and `num_candidate_programs` `BootstrapFewShot`-compiled programs with randomized example sets.

3. **`BootstrapFewShotWithRandomSearch`**: Applies `BootstrapFewShot` several times with random search over generated demonstrations, and selects the best program.
4. **`BootstrapFewShotWithOptuna`**: Applies `BootstrapFewShot` with Optuna optimization across demonstration sets, running trials to maximize evaluation metrics and selecting the best demonstrations.

4. **`BootstrapFewShotWithOptuna`**: Applies `BootstrapFewShot` through Optuna hyperparameter optimization across demonstration sets, running trials to maximize evaluation metrics.
5. **`KNNFewShot`**: Selects demonstrations through a k-Nearest Neighbors algorithm to pick a diverse set of examples from different clusters: it vectorizes the examples, clusters them, and then uses the cluster centers with `BootstrapFewShot` for the bootstrapping/selection process. This is useful when there's a lot of data over random spaces, since KNN helps optimize the `trainset` for `BootstrapFewShot`. See [this notebook](https://github.com/stanfordnlp/dspy/blob/main/examples/knn.ipynb) for an example.
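
As a rough sketch of how these few-shot optimizers are typically invoked (`YourProgram`, `trainset`, and `your_metric` are placeholders assumed to be defined elsewhere):

```python
from dspy.teleprompt import LabeledFewShot, BootstrapFewShot

# LabeledFewShot: attach k labeled demos drawn at random from the trainset.
labeled_optimizer = LabeledFewShot(k=8)
labeled_program = labeled_optimizer.compile(YourProgram(), trainset=trainset)

# BootstrapFewShot: self-generate demos with the program acting as its own teacher,
# keeping only demos that pass your_metric.
bootstrap_optimizer = BootstrapFewShot(
    metric=your_metric,
    max_bootstrapped_demos=4,
    max_labeled_demos=16,
)
bootstrapped_program = bootstrap_optimizer.compile(YourProgram(), trainset=trainset)
```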


#### Automatic Instruction Optimization

4. **`COPRO`**: Generates and refines new instructions for each step, and optimizes them with coordinate ascent.
These optimizers produce optimized instructions for the prompt and, in the case of MIPRO, also optimize the set of few-shot demonstrations; a brief usage sketch follows this list.

5. **`MIPRO`**: Generates instructions and few-shot examples in each step. The instruction generation is data-aware and demonstration-aware. Uses Bayesian Optimization to effectively search over the space of generation instructions/demonstrations across your modules.
6. **`COPRO`**: Generates and refines new instructions for each step, and optimizes them with coordinate ascent (hill-climbing using the metric function and the `trainset`). Parameters include `depth`, the number of prompt-improvement iterations the optimizer runs.

7. **`MIPRO`**: Generates instructions *and* few-shot examples in each step. The instruction generation is data-aware and demonstration-aware. Uses Bayesian Optimization to effectively search over the space of instruction/demonstration candidates across your modules.
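
A minimal sketch of instruction optimization with `COPRO` (again assuming `YourProgram`, `trainset`, and `your_metric`; the parameter values here are illustrative, not recommendations):

```python
from dspy.teleprompt import COPRO

copro = COPRO(
    metric=your_metric,     # scores each candidate instruction on the trainset
    breadth=10,             # candidate instructions generated per round
    depth=3,                # rounds of coordinate-ascent refinement
    init_temperature=1.4,
)
eval_kwargs = dict(num_threads=4, display_progress=True, display_table=0)
optimized_program = copro.compile(YourProgram(), trainset=trainset, eval_kwargs=eval_kwargs)
```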


#### Automatic Finetuning

This optimizer is used to fine-tune the underlying LLM(s).

6. **`BootstrapFinetune`**: Distills a prompt-based DSPy program into weight updates (for smaller LMs). The output is a DSPy program that has the same steps, but where each step is conducted by a finetuned model instead of a prompted LM.
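
A hedged sketch of finetuning with `BootstrapFinetune` (placeholder names as above; check the current signature of `compile` before relying on this):

```python
from dspy.teleprompt import BootstrapFinetune

# Distill the prompt-based program into weight updates for a smaller LM.
finetuner = BootstrapFinetune(metric=your_metric)
finetuned_program = finetuner.compile(YourProgram(), trainset=trainset)
```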


#### Program Transformations

7. **`KNNFewShot`**. Selects demonstrations through k-Nearest Neighbors algorithm integrating `BootstrapFewShot` for bootstrapping/selection process.

8. **`Ensemble`**: Ensembles a set of DSPy programs and either uses the full set or randomly samples a subset into a single program.
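
A sketch of `Ensemble` over several compiled candidates (assuming `programs` is a list of compiled DSPy programs; `dspy.majority` as the reducer is one common choice):

```python
import dspy
from dspy.teleprompt import Ensemble

# Combine several compiled programs into one; reduce_fn aggregates their outputs,
# and size (optional) randomly samples that many programs per call.
ensemble = Ensemble(reduce_fn=dspy.majority, size=3)
ensembled_program = ensemble.compile(programs)
```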


@@ -90,7 +104,7 @@ from dspy.teleprompt import BootstrapFewShotWithRandomSearch

# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 8-shot examples of your program's steps.
# The optimizer will repeat this 10 times (plus some initial attempts) before selecting its best attempt on the devset.
config = dict(max_bootstrapped_demos=3, max_labeled_demos=3, num_candidate_programs=10, num_threads=4)
config = dict(max_bootstrapped_demos=4, max_labeled_demos=4, num_candidate_programs=10, num_threads=4)

teleprompter = BootstrapFewShotWithRandomSearch(metric=YOUR_METRIC_HERE, **config)
optimized_program = teleprompter.compile(YOUR_PROGRAM_HERE, trainset=YOUR_TRAINSET_HERE)
@@ -115,4 +129,4 @@ To load a program from a file, you can instantiate an object from that class and
```python
loaded_program = YOUR_PROGRAM_CLASS()
loaded_program.load(path=YOUR_SAVE_PATH)
```
```
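
Saving the compiled program beforehand is the mirror-image one-liner (a sketch; `optimized_program` and `YOUR_SAVE_PATH` as above):

```python
optimized_program.save(YOUR_SAVE_PATH)
```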
38 changes: 32 additions & 6 deletions dspy/teleprompt/bootstrap.py
@@ -1,5 +1,6 @@
import random
import threading
from typing import Dict, Optional

import tqdm

@@ -15,7 +16,8 @@

# TODO: Switch here from dsp.Example to dspy.Example. Right now, it's okay because it's internal only (predictors).
# NOTE: Notice the places where we don't shuffle examples. I do like that this one doesn't shuffle.
# Other ones that consider options may want to use both unshuffled and then shuffle a few times, when considering candidates.
# Other ones that consider options may want to use both unshuffled and then shuffle a few times, when
# considering candidates.

# TODO: the max_rounds via branch_idx to get past the cache, not just temperature.
# In principle, we can also sample multiple outputs from the final generation step
@@ -25,25 +27,47 @@
# won't hurt our "best effort" guarantees.)

# TODO: When this bootstraps for another teleprompter like finetune, we want all demos we gather.
# But when it's for direct use we may want to sample ONE demo per predictor--example pair. This is important for "multi-use" modules.
# But when it's for direct use we may want to sample ONE demo per predictor--example pair.
# This is important for "multi-use" modules.

# TODO: Add baselines=[...]


class BootstrapFewShot(Teleprompter):
def __init__(
self,
metric=None,
metric_threshold=None,
teacher_settings={},
teacher_settings: Optional[Dict] = None,
max_bootstrapped_demos=4,
max_labeled_demos=16,
max_rounds=1,
max_errors=5,
):
"""
A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
These demos come from a combination of labeled examples in the training set, and bootstrapped demos.

Parameters
----------
metric: Callable
A function that compares an expected value and predicted value, outputting the result of that comparison.
metric_threshold: optional float, default `None`
If the metric yields a numerical value, then check it against this threshold when
deciding whether or not to accept a bootstrap example.
teacher_settings: dict, optional
Settings for the `teacher` model.
max_bootstrapped_demos: int, default 4
Maximum number of bootstrapped demonstrations to include.
max_labeled_demos: int, default 16
Maximum number of labeled demonstrations to include.
max_rounds: int, default 1
Number of iterations to attempt generating the required bootstrap examples. If unsuccessful after `max_rounds`, the program ends.
max_errors: int, default 5
Maximum number of errors before the program ends.
"""
self.metric = metric
self.metric_threshold = metric_threshold
self.teacher_settings = teacher_settings
self.teacher_settings = {} if teacher_settings is None else teacher_settings

self.max_bootstrapped_demos = max_bootstrapped_demos
self.max_labeled_demos = max_labeled_demos
@@ -91,7 +115,9 @@ def _prepare_predictor_mappings(self):
assert name1 == name2, "Student and teacher must have the same program structure."
assert predictor1.signature.equals(
predictor2.signature,
), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
), (f"Student and teacher must have the same signatures. "
f"{type(predictor1.signature)} != {type(predictor2.signature)}"
)
assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."

name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
12 changes: 11 additions & 1 deletion dspy/teleprompt/copro_optimizer.py
@@ -126,7 +126,17 @@ def _set_signature(self, predictor, updated_signature):
predictor.signature = updated_signature

def compile(self, student, *, trainset, eval_kwargs):
"""student is a program that needs to be optimized, note that it may be zero-shot or already pre-optimized for demos != []"""
"""
Optimizes the `signature` of the `student` program; note that it may be zero-shot or already pre-optimized (demos already chosen, i.e. `demos != []`).

parameters:
student: the program to optimize, which is left modified.
trainset: iterable of `Example`s
eval_kwargs: optional, dict
Additional keyword arguments passed to `Evaluate` for the metric.

Returns an optimized version of `student`.
"""
module = student.deepcopy()
evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs)
total_calls = 0
8 changes: 5 additions & 3 deletions dspy/teleprompt/mipro_optimizer.py
@@ -18,19 +18,21 @@
"""
USAGE SUGGESTIONS:

The following code can be used to compile a optimized signature teleprompter using the MIPRO, and evaluate it on an end task:
The following code can be used to compile an optimized signature teleprompter using MIPRO, and evaluate it on an end task:

``` python
from dspy.teleprompt import MIPRO

teleprompter = MIPRO(prompt_model=prompt_model, task_model=task_model, metric=metric, num_candidates=10, init_temperature=1.0)
kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0)
compiled_prompt_opt = teleprompter.compile(program, trainset=trainset[:TRAIN_NUM], num_trials=100, max_bootstrapped_demos=3, max_labeled_demos=5, eval_kwargs=kwargs)
eval_score = evaluate(compiled_prompt_opt, devset=evalset[:EVAL_NUM], **kwargs)
```

Note that this teleprompter takes in the following parameters:

* prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (ie. dspy.settings.configure(lm=task_model)).
* task_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (ie. dspy.settings.configure(lm=task_model)).
* prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (i.e., dspy.settings.configure(lm=task_model)).
* task_model: The model used to execute the task program being optimized. When unspecified, defaults to the model set in settings (i.e., dspy.settings.configure(lm=task_model)).
* metric: The task metric used for optimization.
* num_candidates: The number of new prompts and sets of fewshot examples to generate and evaluate. Default=10.
* init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.0.
2 changes: 1 addition & 1 deletion dspy/teleprompt/teleprompt_optuna.py
@@ -52,7 +52,7 @@ def objective(self, trial):
display_table=False,
display_progress=True,
)
score, _ = evaluate(program2, return_all_scores=True)
score = evaluate(program2, return_all_scores=False)
trial.set_user_attr("program", program2)
return score
