diff --git a/docs/docs/cheatsheet.md b/docs/docs/cheatsheet.md index d6fef8b135..630b9a563f 100644 --- a/docs/docs/cheatsheet.md +++ b/docs/docs/cheatsheet.md @@ -392,6 +392,7 @@ compiled_program_optimized_bayesian_signature = teleprompter.compile(your_dspy_p Note: detailed documentation can be found [here](/docs/deep-dive/optimizers/miprov2.md). `MIPROv2` is the latest extension of `MIPRO` which includes updates such as (1) improvements to instruction proposal and (2) more efficient search with minibatching. #### Optimizing with MIPROv2 +This shows how to perform an easy out-of-the-box run with `auto=light`, which configures many hyperparameters for you and performs a light optimization run. You can alternatively set `auto=medium` or `auto=heavy` to perform longer optimization runs. The more detailed `MIPROv2` documentation [here](/docs/deep-dive/optimizers/miprov2.md) also provides more information about how to set hyperparameters by hand. ```python # Import the optimizer from dspy.teleprompt import MIPROv2 @@ -399,10 +400,7 @@ from dspy.teleprompt import MIPROv2 # Initialize optimizer teleprompter = MIPROv2( metric=gsm8k_metric, - num_candidates=7, - init_temperature=0.5, - verbose=False, - num_threads=4, + auto="light", # Can choose between light, medium, and heavy optimization runs ) # Optimize program @@ -412,10 +410,6 @@ optimized_program = teleprompter.compile( trainset=trainset, max_bootstrapped_demos=3, max_labeled_demos=4, - num_trials=15, - minibatch_size=25, - minibatch_full_eval_steps=10, - minibatch=True, requires_permission_to_run=False, ) @@ -429,35 +423,31 @@ evaluate(optimized_program, devset=devset[:]) #### Optimizing instructions only with MIPROv2 (0-Shot) ```python -# Import optimizer +# Import the optimizer from dspy.teleprompt import MIPROv2 -# Initialize optimizer +# Initialize optimizer teleprompter = MIPROv2( metric=gsm8k_metric, - num_candidates=15, - init_temperature=0.5, - verbose=False, - num_threads=4, + auto="light", # Can choose between light, medium, and heavy optimization runs ) -# Perform optimization -print(f"Optimizing program with MIPRO (0-Shot)...") -zeroshot_optimized_program = teleprompter.compile( +# Optimize program +print(f"Optimizing program with MIPRO...") +optimized_program = teleprompter.compile( program.deepcopy(), trainset=trainset, - max_bootstrapped_demos=0, # setting demos to 0 for 0-shot optimization + max_bootstrapped_demos=0, max_labeled_demos=0, - num_trials=15, - minibatch=False, requires_permission_to_run=False, ) +# Save optimized program for future use +optimized_program.save(f"mipro_optimized") -zeroshot_optimized_program.save(f"mipro_0shot_optimized") - -print(f"Evaluate optimized program...") -evaluate(zeroshot_optimized_program, devset=devset[:]) +# Evaluate optimized program +print(f"Evaluate optimized program...") +evaluate(optimized_program, devset=devset[:]) ``` ### Signature Optimizer with Types diff --git a/docs/docs/deep-dive/optimizers/miprov2.md b/docs/docs/deep-dive/optimizers/miprov2.md index 1ef388ea1f..759f003f12 100644 --- a/docs/docs/deep-dive/optimizers/miprov2.md +++ b/docs/docs/deep-dive/optimizers/miprov2.md @@ -6,6 +6,7 @@ - [Setting up a Sample Pipeline](#setting-up-a-sample-pipeline) - [Optimizing with MIPROv2](#optimizing-with-miprov2) - [Optimizing instructions only with MIPROv2 (0-Shot)](#optimizing-instructions-only-with-miprov2-0-shot) + - [Optimizing with MIPROv2 (Advanced)](#optimizing-with-miprov2-advanced) 3. 
[Parameters](#parameters) - [Initialization Parameters](#initialization-parameters) - [Compile Method Specific Parameters](#compile-method-specific-parameters) @@ -69,7 +70,7 @@ evaluate(program, devset=devset[:]) Now we have the baseline pipeline ready to use, so let's try using the `MIPROv2` optimizer to improve our pipeline's performance! ### Optimizing with `MIPROv2` -Here we show how to import, initialize, and compile our program with optimized few-shot examples and instructions using `MIPROv2`. +To get started with `MIPROv2`, we recommend using the `auto` flag, starting with a `light` optimization run. This configures the hyperparameters for you and performs a light optimization run on your program. ```python # Import the optimizer @@ -78,10 +79,7 @@ from dspy.teleprompt import MIPROv2 # Initialize optimizer teleprompter = MIPROv2( metric=gsm8k_metric, - num_candidates=7, - init_temperature=0.5, - verbose=False, - num_threads=4, + auto="light", # Can choose between light, medium, and heavy optimization runs ) # Optimize program @@ -91,10 +89,6 @@ optimized_program = teleprompter.compile( trainset=trainset, max_bootstrapped_demos=3, max_labeled_demos=4, - num_trials=15, - minibatch_size=25, - minibatch_full_eval_steps=10, - minibatch=True, requires_permission_to_run=False, ) @@ -106,41 +100,72 @@ print(f"Evluate optimized program...") evaluate(optimized_program, devset=devset[:]) ``` -### Optimizing instructions only with `MIPROv2` (0-Shot) +#### Optimizing instructions only with `MIPROv2` (0-Shot) In some cases, we may want to only optimize the instruction, rather than including few-shot examples in the prompt. The code below demonstrates how this can be done using `MIPROv2`. Note that the key difference involves setting `max_labeled_demos` and `max_bootstrapped_demos` to zero. ```python -# Import optimizer +# Import the optimizer from dspy.teleprompt import MIPROv2 -# Initialize optimizer +# Initialize optimizer +teleprompter = MIPROv2( + metric=gsm8k_metric, + auto="light", # Can choose between light, medium, and heavy optimization runs +) + +# Optimize program +print(f"Optimizing zero-shot program with MIPRO...") +zeroshot_optimized_program = teleprompter.compile( + program.deepcopy(), + trainset=trainset, + max_bootstrapped_demos=0, # ZERO FEW-SHOT EXAMPLES + max_labeled_demos=0, # ZERO FEW-SHOT EXAMPLES + requires_permission_to_run=False, +) + +# Save optimized program for future use +zeroshot_optimized_program.save(f"mipro_zeroshot_optimized") + +# Evaluate optimized program +print(f"Evaluate optimized program...") +evaluate(zeroshot_optimized_program, devset=devset[:]) +``` + +#### Optimizing with `MIPROv2` (Advanced) +Once you've gotten a feel for using `MIPROv2` with `auto` settings, you may want to experiment with setting hyperparameters yourself to get the best results. The code below shows an example of how you can go about this. A full description of each parameter can be found in the section below. 
+ +```python +# Import the optimizer +from dspy.teleprompt import MIPROv2 + +# Initialize optimizer teleprompter = MIPROv2( metric=gsm8k_metric, num_candidates=7, init_temperature=0.5, + max_bootstrapped_demos=3, + max_labeled_demos=4, verbose=False, - num_threads=4, ) -# Perform optimization -print(f"Optimizing program with MIPRO (0-Shot)...") -zeroshot_optimized_program = teleprompter.compile( +# Optimize program +print(f"Optimizing program with MIPRO...") +optimized_program = teleprompter.compile( program.deepcopy(), trainset=trainset, - max_bootstrapped_demos=0, # setting demos to 0 for 0-shot optimization - max_labeled_demos=0, num_trials=15, minibatch_size=25, minibatch_full_eval_steps=10, - minibatch=False, + minibatch=True, requires_permission_to_run=False, ) +# Save optimized program for future use +optimized_program.save(f"mipro_optimized") -zeroshot_optimized_program.save(f"mipro_0shot_optimized") - -print(f"Evaluate optimized program...") -evaluate(zeroshot_optimized_program, devset=devset[:]) +# Evaluate optimized program +print(f"Evaluate optimized program...") +evaluate(optimized_program, devset=devset[:]) ``` ## Parameters @@ -149,32 +174,36 @@ evaluate(zeroshot_optimized_program, devset=devset[:]) ### Initialization Parameters | Parameter | Type | Default | Description | |----------------------|--------------|------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------| -| `metric` | `dspy.metric` | N/A - Required | The evaluation metric used to optimize the task model. | +| `metric` | `dspy.metric` | **Required** | The evaluation metric used to optimize the task model. | | `prompt_model` | `dspy.LM` | LM specified in `dspy.settings` | Model used for prompt generation. | | `task_model` | `dspy.LM` | LM specified in `dspy.settings` | Model used for task execution. | +| `auto` | `Optional[str]` | None | If set to `light`, `medium`, or `heavy`, this will automatically configure the following hyperparameters: `num_candidates`, `num_trials`, `minibatch`, and will also cap the size of `valset` at 100, 300, and 1000 for `light`, `medium`, and `heavy` runs respectively. | | `num_candidates` | `int` | `10` | Number of candidate instructions & few-shot examples to generate and evaluate for each predictor. If `num_candidates=10`, this means for a 2 module LM program we'll be optimizing over 10 candidates x 2 modules x 2 variables (few-shot ex. and instructions for each module)= 40 total variables. Therfore, if we increase `num_candidates`, we will probably want to increase `num_trials` as well (see Compile parameters). | | `num_threads` | `int` | `6` | Threads to use for evaluation. | | `max_errors` | `int` | `10` | Maximum errors during an evaluation run that can be made before throwing an Exception. | | `teacher_settings` | `dict` | `{}` | Settings to use for the teacher model that bootstraps few-shot examples. An example dict would be `{lm=}`. If your LM program with your default model is struggling to bootstrap any examples, it could be worth using a more powerful teacher model for bootstrapping. | +| `max_bootstrapped_demos` | `int` | `4` | Maximum number of bootstrapped demonstrations to generate and include in the prompt. | +| `max_labeled_demos` | `int` | `16` | Maximum number of labeled demonstrations to generate and include in the prompt. 
Note that these differ from bootstrapped examples because they are just inputs & outputs sampled directly from the training set and do not have bootstrapped intermediate steps. | `init_temperature` | `float` | `1.0` | The initial temperature for prompt generation, influencing creativity. | | `verbose` | `bool` | `False` | Enables printing intermediate steps and information. | | `track_stats` | `bool` | `True` | Logs relevant information through the optimization process if set to True. | | `metric_threshold` | `float` | `None` | A metric threshold is used if we only want to keep bootstrapped few-shot examples that exceed some threshold of performance. | +| `seed` | `int` | `9` | Seed for reproducibility. | ### Compile Parameters | Parameter | Type | Default | Description | |----------------------------|----------|---------|----------------------------------------------------------------------------------------------------------| -| `student` | `dspy.Module` | N/A (Required) | The base program to optimize. | -| `trainset` | `List[dspy.Example]` | N/A (Required) | Training dataset which is used to bootstrap few-shot examples and instructions. If a separate `valset` is not specified, 80% of this training set will also be used as a validation set for evaluating new candidate prompts. | +| `student` | `dspy.Module` | **Required** | The base program to optimize. | +| `trainset` | `List[dspy.Example]` | **Required** | Training dataset which is used to bootstrap few-shot examples and instructions. If a separate `valset` is not specified, 80% of this training set will also be used as a validation set for evaluating new candidate prompts. | | `valset` | `List[dspy.Example]` | Defaults to 80% of trainset | Dataset which is used to evaluate candidate prompts. We recommend using somewhere between 50-500 examples for optimization. | | `num_trials` | `int` | `30` | Number of optimization trials to run. When `minibatch` is set to `True`, this represents the number of minibatch trials that will be run on batches of size `minibatch_size`. When minibatch is set to `False`, each trial uses a full evaluation on the training set. In both cases, we recommend setting `num_trials` to a *minimum* of .75 x # modules in program x # variables per module (2 if few-shot examples & instructions will both be optimized, 1 in the 0-shot case). | | `minibatch` | `bool` | `True` | Flag to enable evaluating over minibatches of data (instead of the full validation set) for evaluation each trial. | | `minibatch_size` | `int` | `25.0` | Size of minibatches for evaluations. | | `minibatch_full_eval_steps` | `int` | `10` | When minibatching is enabled, a full evaluation on the validation set will be carried out every `minibatch_full_eval_steps` on the top averaging set of prompts (according to their average score on the minibatch trials). -| `max_bootstrapped_demos` | `int` | `4` | Maximum number of bootstrapped demonstrations to generate and include in the prompt. | -| `max_labeled_demos` | `int` | `16` | Maximum number of labeled demonstrations to generate and include in the prompt. Note that these differ from bootstrapped examples because they are just inputs & outputs sampled directly from the training set and do not have bootstrapped intermediate steps. | -| `seed` | `int` | `9` | Seed for reproducibility. | | +| `max_bootstrapped_demos` | `Optional[int]` | Defaults to `init` value. | Maximum number of bootstrapped demonstrations to generate and include in the prompt. 
| +| `max_labeled_demos` | `Optional[int]` | Defaults to `init` value. | Maximum number of labeled demonstrations to generate and include in the prompt. Note that these differ from bootstrapped examples because they are just inputs & outputs sampled directly from the training set and do not have bootstrapped intermediate steps. | +| `seed` | `Optional[int]` | Defaults to `init` value. | Seed for reproducibility. | | | `program_aware_proposer` | `bool` | `True` | Flag to enable summarizing a reflexive view of the code for your LM program. | | `data_aware_proposer` | `bool` | `True` | Flag to enable summarizing your training dataset. | | `view_data_batch_size` | `int` | `10` | Number of data examples to look at a time when generating the summary. | @@ -189,7 +218,7 @@ At a high level, `MIPROv2` works by creating both few-shot examples and new inst These steps are broken down in more detail below: 1) **Bootstrap Few-Shot Examples**: The same bootstrapping technique used in `BootstrapFewshotWithRandomSearch` is used to create few-shot examples. This works by randomly sampling examples from your training set, which are then run through your LM program. If the output from the program is correct for this example, it is kept as a valid few-shot example candidate. Otherwise, we try another example until we've curated the specified amount of few-shot example candidates. This step creates `num_candidates` sets of `max_bootstrapped_demos` bootstrapped examples and `max_labeled_demos` basic examples sampled from the training set. 2) **Propose Instruction Candidates**. Next, we propose instruction candidates for each predictor in the program. This is done using another LM program as a proposer, which bootstraps & summarizes relevant information about the task to generate high quality instructions. Specifically, the instruction proposer includes (1) a generated summary of properties of the training dataset, (2) a generated summary of your LM program's code and the specific predictor that an instruction is being generated for, (3) the previously bootstrapped few-shot examples to show reference inputs / outputs for a given predictor and (4) a randomly sampled tip for generation (i.e. "be creative", "be concise", etc.) to help explore the feature space of potential instructions. -3. **Find an Optimized Combination of Few-Shot Examples & Instructions**. Finally, now that we've created these few-shot examples and instructions, we use Bayesian Optimization to choose which set of these would work best for each predictor in our program. This works by running a series of `num_batches` trials, where a new set of prompts are evaluated over our validation set at each trial. This helps the Bayesian Optimizer learn which combination of prompts work best over time. If `minibatch` is set to `True` (which it is by default), then the new set of prompts are only evaluated on a minibatch of size `minibatch_size` at each trial which generally allows for more efficient exploration / exploitation. The best averaging set of prompts is then evalauted on the full validation set every `minibatch_full_eval_steps` get a less noisey performance benchmark. At the end of the optimization process, the LM program with the set of prompts that performed best on the full validation set is returned. +3. **Find an Optimized Combination of Few-Shot Examples & Instructions**. 
Finally, now that we've created these few-shot examples and instructions, we use Bayesian Optimization to choose which set of these would work best for each predictor in our program. This works by running a series of `num_trials` trials, where a new set of prompts is evaluated over our validation set at each trial. This helps the Bayesian Optimizer learn which combination of prompts works best over time. If `minibatch` is set to `True` (which it is by default), then the new set of prompts is only evaluated on a minibatch of size `minibatch_size` at each trial, which generally allows for more efficient exploration / exploitation. The best averaging set of prompts is then evaluated on the full validation set every `minibatch_full_eval_steps` trials to get a less noisy performance benchmark. At the end of the optimization process, the LM program with the set of prompts that performed best on the full validation set is returned. For those interested in more details, more information on `MIPROv2` along with a study on `MIPROv2` compared with other DSPy optimizers can be found in [this paper](https://arxiv.org/abs/2406.11695). \ No newline at end of file diff --git a/dspy/propose/grounded_proposer.py b/dspy/propose/grounded_proposer.py index d1d6192bbc..5b70e7b21f 100644 --- a/dspy/propose/grounded_proposer.py +++ b/dspy/propose/grounded_proposer.py @@ -187,6 +187,7 @@ def forward( # Summarize the program program_description = "Not available" module_code = "Not provided" + module_description = "Not provided" if self.program_aware: try: program_description = strip_prefix( @@ -209,18 +210,18 @@ def forward( outputs.append(field_name) module_code = f"{program.predictors()[pred_i].__class__.__name__}({', '.join(inputs)}) -> {', '.join(outputs)}" + + module_description = self.describe_module( + program_code=self.program_code_string, + program_description=program_description, + program_example=task_demos, + module=module_code, + max_depth=10, + ).module_description except: if self.verbose: print("Error getting program description. 
Running without program aware proposer.") self.program_aware = False - module_description = self.describe_module( - program_code=self.program_code_string, - program_description=program_description, - program_example=task_demos, - module=module_code, - max_depth=10, - ).module_description - # Generate an instruction for our chosen module if self.verbose: print(f"task_demos {task_demos}") instruct = self.generate_module_instruction( @@ -258,6 +259,7 @@ def __init__( set_tip_randomly=True, set_history_randomly=True, verbose=False, + rng=None ): super().__init__() self.program_aware = program_aware @@ -268,22 +270,30 @@ def __init__( self.set_tip_randomly=set_tip_randomly self.set_history_randomly=set_history_randomly self.verbose = verbose + self.rng = rng or random self.prompt_model = get_prompt_model(prompt_model) + + self.program_code_string = None if self.program_aware: try: self.program_code_string = get_dspy_source_code(program) if self.verbose: print("SOURCE CODE:",self.program_code_string) except Exception as e: print(f"Error getting source code: {e}.\n\nRunning without program aware proposer.") - self.program_code_string = None self.program_aware = False - else: - self.program_code_string = None - self.data_summary = create_dataset_summary( - trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model, - ) - if self.verbose: print(f"DATA SUMMARY: {self.data_summary}") + + self.data_summary = None + if self.use_dataset_summary: + try: + self.data_summary = create_dataset_summary( + trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model, + ) + if self.verbose: print(f"DATA SUMMARY: {self.data_summary}") + except Exception as e: + print(f"Error getting data summary: {e}.\n\nRunning without data aware proposer.") + self.use_dataset_summary = False + print("") def propose_instructions_for_program( self, @@ -301,7 +311,7 @@ def propose_instructions_for_program( if self.set_history_randomly: # Randomly select whether or not we're using instruction history - use_history = random.random() < 0.5 + use_history = self.rng.random() < 0.5 self.use_instruct_history = use_history if self.verbose: print(f"Use history T/F: {self.use_instruct_history}") @@ -319,7 +329,7 @@ def propose_instructions_for_program( if self.set_tip_randomly: if self.verbose: print("Using a randomly generated configuration for our grounded proposer.") # Randomly select the tip - selected_tip_key = random.choice(list(TIPS.keys())) + selected_tip_key = self.rng.choice(list(TIPS.keys())) selected_tip = TIPS[selected_tip_key] self.use_tip = bool( selected_tip, diff --git a/dspy/teleprompt/mipro_optimizer_v2.py b/dspy/teleprompt/mipro_optimizer_v2.py index ab758f7e62..6603a259b2 100644 --- a/dspy/teleprompt/mipro_optimizer_v2.py +++ b/dspy/teleprompt/mipro_optimizer_v2.py @@ -2,6 +2,7 @@ import sys import textwrap from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import optuna @@ -15,75 +16,66 @@ eval_candidate_program, get_program_with_highest_avg_score, get_signature, - get_task_model_history_for_full_example, print_full_program, save_candidate_program, set_signature, + create_minibatch, ) -""" -USAGE SUGGESTIONS: - -The following code can be used to compile a optimized signature teleprompter using MIPRO, and evaluate it on an end task: - -``` python -from dspy.teleprompt import MIPROv2 - -teleprompter = MIPROv2(prompt_model=prompt_model, task_model=task_model, metric=metric, num_candidates=10, 
init_temperature=1.0) -kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0) -compiled_prompt_opt = teleprompter.compile(program, trainset=trainset[:TRAIN_NUM], num_trials=100, max_bootstrapped_demos=3, max_labeled_demos=5) -eval_score = evaluate(compiled_prompt_opt, devset=valset[:EVAL_NUM], **kwargs) -``` - -Note that this teleprompter takes in the following parameters: - -* prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (i.e., dspy.settings.configure(lm=task_model)). -* task_model: The model used for running your task. When unspecified, defaults to the model set in settings (i.e., dspy.settings.configure(lm=task_model)). -* teacher_settings: The settings used for the teacher model. When unspecified, defaults to the settings set in settings (i.e., dspy.settings.configure(lm=task_model)). - The teacher settings are used to generate the fewshot examples. This is the LLM/settings to use as a task model for the bootstrapping runs. - Typically you would want to use a model of equal or greater quality to your task model. -* metric: The task metric used for optimization. -* num_candidates: The number of new prompts and sets of fewshot examples to generate and evaluate. Default=10. -* init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.0. -* verbose: Tells the method whether or not to print intermediate steps. -* track_stats: Tells the method whether or not to track statistics about the optimization process. - If True, the method will track a dictionary with a key corresponding to the trial number, - and a value containing a dict with the following keys: - * program: the program being evaluated at a given trial - * score: the last average evaluated score for the program - * pruned: whether or not this program was pruned - This information will be returned as attributes of the best program. -* log_dir: The directory to save logs and other information to. If unspecified, no logs will be saved. -* view_data_batch_size: The number of examples to view in the data batch when producing the dataset summary. Default=10. -* minibatch_size: The size of the minibatch to use when evaluating the program if using minibatched evaluations. Default=25. -* minibatch_full_eval_steps: The number of steps to take before doing a full evaluation of the program if using minibatched evaluations. Default=10. -* metric_threshold: If the metric yields a numerical value, then check it against this threshold when deciding whether or not to accept a bootstrap example. 
-""" - +# Constants BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT = 3 LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT = 0 +MIN_MINIBATCH_SIZE = 50 + +AUTO_RUN_SETTINGS = { + "light": {"num_trials": 7, "val_size": 100}, + "medium": {"num_trials": 25, "val_size": 300}, + "heavy": {"num_trials": 50, "val_size": 1000}, +} + +# ANSI escape codes for colors +YELLOW = "\033[93m" +GREEN = "\033[92m" +BLUE = "\033[94m" +BOLD = "\033[1m" +ENDC = "\033[0m" # Resets the color to default + class MIPROv2(Teleprompter): def __init__( self, - metric, - prompt_model=dspy.settings.lm, - task_model=dspy.settings.lm, - teacher_settings={}, - num_candidates=10, - num_threads=6, - max_errors=10, - init_temperature=0.5, - verbose=False, - track_stats=True, - log_dir=None, - metric_threshold=None, + metric: Callable, + prompt_model: Optional[Any] = None, + task_model: Optional[Any] = None, + teacher_settings: Dict = {}, + max_bootstrapped_demos: int = 4, + max_labeled_demos: int = 16, + auto: Optional[str] = None, + num_candidates: int = 10, + num_threads: int = 6, + max_errors: int = 10, + seed: int = 9, + init_temperature: float = 0.5, + verbose: bool = False, + track_stats: bool = True, + log_dir: Optional[str] = None, + metric_threshold: Optional[float] = None, ): + # Validate 'auto' parameter + allowed_modes = {None, "light", "medium", "heavy"} + if auto not in allowed_modes: + raise ValueError( + f"Invalid value for auto: {auto}. Must be one of {allowed_modes}." + ) + self.auto = auto + self.num_candidates = num_candidates self.metric = metric self.init_temperature = init_temperature - self.task_model = task_model - self.prompt_model = prompt_model + self.task_model = task_model if task_model else dspy.settings.lm + self.prompt_model = prompt_model if prompt_model else dspy.settings.lm + self.max_bootstrapped_demos = max_bootstrapped_demos + self.max_labeled_demos = max_labeled_demos self.verbose = verbose self.track_stats = track_stats self.log_dir = log_dir @@ -93,411 +85,647 @@ def __init__( self.num_threads = num_threads self.max_errors = max_errors self.metric_threshold = metric_threshold + self.seed = seed + self.rng = None def compile( self, - student, + student: Any, *, - trainset, - valset=None, - num_trials=30, - max_bootstrapped_demos=4, - max_labeled_demos=16, - seed=9, - minibatch=True, - minibatch_size=25, - minibatch_full_eval_steps=10, - program_aware_proposer=True, - data_aware_proposer=True, - view_data_batch_size=10, - tip_aware_proposer=True, - fewshot_aware_proposer=True, - requires_permission_to_run=True, - ): - # Define ANSI escape codes for colors - YELLOW = "\033[93m" - GREEN = "\033[92m" - BLUE = "\033[94m" - BOLD = "\033[1m" - ENDC = "\033[0m" # Resets the color to default + trainset: List, + valset: Optional[List] = None, + num_trials: int = 30, + max_bootstrapped_demos: Optional[int] = None, + max_labeled_demos: Optional[int] = None, + seed: Optional[int] = None, + minibatch: bool = True, + minibatch_size: int = 25, + minibatch_full_eval_steps: int = 10, + program_aware_proposer: bool = True, + data_aware_proposer: bool = True, + view_data_batch_size: int = 10, + tip_aware_proposer: bool = True, + fewshot_aware_proposer: bool = True, + requires_permission_to_run: bool = True, + ) -> Any: + # Set random seeds + seed = seed or self.seed + self._set_random_seeds(seed) + + # Update max demos if specified + if max_bootstrapped_demos is not None: + self.max_bootstrapped_demos = max_bootstrapped_demos + if max_labeled_demos is not None: + self.max_labeled_demos = max_labeled_demos + + # Set training & 
validation sets + trainset, valset = self._set_and_validate_datasets(trainset, valset) + + # Set hyperparameters based on run mode (if set) + zeroshot_opt = (self.max_bootstrapped_demos == 0) and ( + self.max_labeled_demos == 0 + ) + num_trials, valset, minibatch = self._set_hyperparams_from_run_mode( + student, num_trials, minibatch, zeroshot_opt, valset + ) + + if self.auto: + self._print_auto_run_settings(num_trials, minibatch, valset) - random.seed(seed) + if minibatch and minibatch_size > len(valset): + raise ValueError( + f"Minibatch size cannot exceed the size of the valset. Valset size: {len(valset)}." + ) - # Validate inputs + # Estimate LM calls and get user confirmation + if requires_permission_to_run: + if not self._get_user_confirmation( + student, + num_trials, + minibatch, + minibatch_size, + minibatch_full_eval_steps, + valset, + program_aware_proposer, + ): + print("Compilation aborted by the user.") + return student # Return the original student program + + # Initialize program and evaluator + program = student.deepcopy() + evaluate = Evaluate( + devset=valset, + metric=self.metric, + num_threads=self.num_threads, + max_errors=self.max_errors, + display_table=False, + display_progress=True, + ) + + # Step 1: Bootstrap few-shot examples + demo_candidates = self._bootstrap_fewshot_examples(program, trainset, seed) + + # Step 2: Propose instruction candidates + instruction_candidates = self._propose_instructions( + program, + trainset, + demo_candidates, + view_data_batch_size, + program_aware_proposer, + data_aware_proposer, + tip_aware_proposer, + fewshot_aware_proposer, + ) + + # If zero-shot, discard demos + if zeroshot_opt: + demo_candidates = None + + # Step 3: Find optimal prompt parameters + best_program = self._optimize_prompt_parameters( + program, + instruction_candidates, + demo_candidates, + evaluate, + valset, + num_trials, + minibatch, + minibatch_size, + minibatch_full_eval_steps, + seed, + ) + + return best_program + + def _set_random_seeds(self, + seed + ): + self.rng = random.Random(seed) + np.random.seed(seed) + + def _set_hyperparams_from_run_mode( + self, + program: Any, + num_trials: int, + minibatch: bool, + zeroshot_opt: bool, + valset: List, + ) -> Tuple[int, List, bool]: + if self.auto is None: + return num_trials, valset, minibatch + + num_vars = len(program.predictors()) + if not zeroshot_opt: + num_vars *= 2 # Account for few-shot examples + instruction variables + + auto_settings = AUTO_RUN_SETTINGS[self.auto] + num_trials = auto_settings["num_trials"] + valset = create_minibatch(valset, batch_size=auto_settings["val_size"], rng=self.rng) + minibatch = len(valset) > MIN_MINIBATCH_SIZE + self.num_candidates = int( + np.round(np.min([num_trials * num_vars, (1.5 * num_trials) / num_vars])) + ) + + return num_trials, valset, minibatch + + def _set_and_validate_datasets(self, trainset: List, valset: Optional[List]): if not trainset: raise ValueError("Trainset cannot be empty.") - if not valset: + if valset is None: if len(trainset) < 2: - raise ValueError("Trainset must have at least 2 examples if no valset specified, or at least 1 example with external validation set.") - - valset_size = min(500, max(1, int(len(trainset) * 0.80))) # 80% of trainset, capped at 500 + raise ValueError( + "Trainset must have at least 2 examples if no valset specified." 
+ ) + valset_size = min(500, max(1, int(len(trainset) * 0.80))) cutoff = len(trainset) - valset_size valset = trainset[cutoff:] trainset = trainset[:cutoff] - else: if len(valset) < 1: - raise ValueError("Validation set must have at least 1 example if specified.") - - if minibatch and minibatch_size > len(valset): - raise ValueError(f"Minibatch size cannot exceed the size of the valset. Note that your validation set contains {len(valset)} examples. Your train set contains {len(trainset)} examples.") - - if minibatch and num_trials < minibatch_full_eval_steps: - raise ValueError(f"Number of trials (num_trials={num_trials}) must be greater than or equal to the number of minibatch full eval steps (minibatch_full_eval_steps={minibatch_full_eval_steps}).") - - estimated_prompt_model_calls = 10 + self.num_candidates * len( - student.predictors(), - ) + (0 if not program_aware_proposer else len(student.predictors()) + 1) # num data summary calls + N * P + (P + 1) - - prompt_model_line = "" - if not program_aware_proposer: - prompt_model_line = f"""{YELLOW}- Prompt Model: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + {BLUE}{BOLD}{self.num_candidates}{ENDC}{YELLOW} * {BLUE}{BOLD}{len(student.predictors())}{ENDC}{YELLOW} lm calls in program = {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}""" - else: - prompt_model_line = f"""{YELLOW}- Prompt Model: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + {BLUE}{BOLD}{self.num_candidates}{ENDC}{YELLOW} * {BLUE}{BOLD}{len(student.predictors())}{ENDC}{YELLOW} lm calls in program + ({BLUE}{BOLD}{len(student.predictors()) + 1}{ENDC}{YELLOW}) lm calls in program aware proposer = {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}""" + raise ValueError("Validation set must have at least 1 example.") + + return trainset, valset + + def _print_auto_run_settings(self, num_trials: int, minibatch: bool, valset: List): + print( + f"\nRUNNING WITH THE FOLLOWING {self.auto.upper()} AUTO RUN SETTINGS:" + f"\nnum_trials: {num_trials}" + f"\nminibatch: {minibatch}" + f"\nnum_candidates: {self.num_candidates}" + f"\nvalset size: {len(valset)}\n" + ) - estimated_task_model_calls_wo_module_calls = 0 - task_model_line = "" + def _estimate_lm_calls( + self, + program: Any, + num_trials: int, + minibatch: bool, + minibatch_size: int, + minibatch_full_eval_steps: int, + valset: List, + program_aware_proposer: bool, + ) -> Tuple[str, str]: + num_predictors = len(program.predictors()) + + # Estimate prompt model calls + estimated_prompt_model_calls = ( + 10 # Data summarizer calls + + self.num_candidates * num_predictors # Candidate generation + + ( + num_predictors + 1 if program_aware_proposer else 0 + ) # Program-aware proposer + ) + prompt_model_line = ( + f"{YELLOW}- Prompt Generation: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + " + f"{BLUE}{BOLD}{self.num_candidates}{ENDC}{YELLOW} * " + f"{BLUE}{BOLD}{num_predictors}{ENDC}{YELLOW} lm calls in program " + f"+ ({BLUE}{BOLD}{num_predictors + 1}{ENDC}{YELLOW}) lm calls in program-aware proposer " + f"= {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}" + ) + + # Estimate task model calls if not minibatch: - estimated_task_model_calls_wo_module_calls = len(trainset) * num_trials # M * T * P - task_model_line = f"""{YELLOW}- Task Model: {BLUE}{BOLD}{len(valset)}{ENDC}{YELLOW} examples in val set * {BLUE}{BOLD}{num_trials}{ENDC}{YELLOW} batches * {BLUE}{BOLD}# of LM calls in your program{ENDC}{YELLOW} = 
({BLUE}{BOLD}{estimated_task_model_calls_wo_module_calls} * # of LM calls in your program{ENDC}{YELLOW}) task model calls{ENDC}""" + estimated_task_model_calls = len(valset) * num_trials + task_model_line = ( + f"{YELLOW}- Program Evaluation: {BLUE}{BOLD}{len(valset)}{ENDC}{YELLOW} examples in val set * " + f"{BLUE}{BOLD}{num_trials}{ENDC}{YELLOW} batches = " + f"{BLUE}{BOLD}{estimated_task_model_calls}{ENDC}{YELLOW} LM program calls{ENDC}" + ) else: - estimated_task_model_calls_wo_module_calls = minibatch_size * num_trials + (len(trainset) * (num_trials // minibatch_full_eval_steps)) # B * T * P - task_model_line = f"""{YELLOW}- Task Model: {BLUE}{BOLD}{minibatch_size}{ENDC}{YELLOW} examples in minibatch * {BLUE}{BOLD}{num_trials}{ENDC}{YELLOW} batches + {BLUE}{BOLD}{len(valset)}{ENDC}{YELLOW} examples in val set * {BLUE}{BOLD}{num_trials // minibatch_full_eval_steps}{ENDC}{YELLOW} full evals = {BLUE}{BOLD}{estimated_task_model_calls_wo_module_calls}{ENDC}{YELLOW} task model calls{ENDC}""" - + full_eval_steps = num_trials // minibatch_full_eval_steps + 1 + estimated_task_model_calls = ( + minibatch_size * num_trials + len(valset) * full_eval_steps + ) + task_model_line = ( + f"{YELLOW}- Program Evaluation: {BLUE}{BOLD}{minibatch_size}{ENDC}{YELLOW} examples in minibatch * " + f"{BLUE}{BOLD}{num_trials}{ENDC}{YELLOW} batches + " + f"{BLUE}{BOLD}{len(valset)}{ENDC}{YELLOW} examples in val set * " + f"{BLUE}{BOLD}{full_eval_steps}{ENDC}{YELLOW} full evals = " + f"{BLUE}{BOLD}{estimated_task_model_calls}{ENDC}{YELLOW} LM Program calls{ENDC}" + ) + + return prompt_model_line, task_model_line - user_message = textwrap.dedent(f"""\ + def _get_user_confirmation( + self, + program: Any, + num_trials: int, + minibatch: bool, + minibatch_size: int, + minibatch_full_eval_steps: int, + valset: List, + program_aware_proposer: bool, + ) -> bool: + prompt_model_line, task_model_line = self._estimate_lm_calls( + program, + num_trials, + minibatch, + minibatch_size, + minibatch_full_eval_steps, + valset, + program_aware_proposer, + ) + + user_message = textwrap.dedent( + f"""\ {YELLOW}{BOLD}Projected Language Model (LM) Calls{ENDC} - Please be advised that based on the parameters you have set, the maximum number of LM calls is projected as follows: + Based on the parameters you have set, the maximum number of LM calls is projected as follows: - {prompt_model_line} {task_model_line} {YELLOW}{BOLD}Estimated Cost Calculation:{ENDC} {YELLOW}Total Cost = (Number of calls to task model * (Avg Input Token Length per Call * Task Model Price per Input Token + Avg Output Token Length per Call * Task Model Price per Output Token) - + (Number of calls to prompt model * (Avg Input Token Length per Call * Task Prompt Price per Input Token + Avg Output Token Length per Call * Prompt Model Price per Output Token).{ENDC} + + (Number of program calls * (Avg Input Token Length per Call * Task Prompt Price per Input Token + Avg Output Token Length per Call * Prompt Model Price per Output Token).{ENDC} For a preliminary estimate of potential costs, we recommend you perform your own calculations based on the task and prompt models you intend to use. 
If the projected costs exceed your budget or expectations, you may consider: {YELLOW}- Reducing the number of trials (`num_trials`), the size of the valset, or the number of LM calls in your program.{ENDC} - {YELLOW}- Using a cheaper task model to optimize the prompt.{ENDC}\n""") + {YELLOW}- Using a cheaper task model to optimize the prompt.{ENDC} + {YELLOW}- Setting `minibatch=True` if you haven't already.{ENDC}\n""" + ) - user_confirmation_message = textwrap.dedent(f"""\ + user_confirmation_message = textwrap.dedent( + f"""\ To proceed with the execution of this program, please confirm by typing {BLUE}'y'{ENDC} for yes or {BLUE}'n'{ENDC} for no. If you would like to bypass this confirmation step in future executions, set the {YELLOW}`requires_permission_to_run`{ENDC} flag to {YELLOW}`False`{ENDC} when calling compile. {YELLOW}Awaiting your input...{ENDC} - """) - - if requires_permission_to_run: print(user_message) - - sys.stdout.flush() # Flush the output buffer to force the message to print + """ + ) + + print(user_message) + sys.stdout.flush() + print(user_confirmation_message) + user_input = input("Do you wish to continue? (y/n): ").strip().lower() + return user_input == "y" + + def _bootstrap_fewshot_examples( + self, program: Any, trainset: List, seed: int + ) -> Optional[List]: + print("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==") + if self.max_bootstrapped_demos > 0: + print( + "These will be used as few-shot example candidates for our program and for creating instructions.\n" + ) + else: + print("These will be used for informing instruction proposal.\n") - run = True - # TODO: make sure these estimates are good for mini-batching - if requires_permission_to_run: - print(user_confirmation_message) - user_input = input("Do you wish to continue? 
(y/n): ").strip().lower() - if user_input != "y": - print("Compilation aborted by the user.") - run = False + print(f"Bootstrapping N={self.num_candidates} sets of demonstrations...") - if run: - # Setup random seeds - random.seed(seed) - np.random.seed(seed) + zeroshot = self.max_bootstrapped_demos == 0 and self.max_labeled_demos == 0 - # Set up program and evaluation function - program = student.deepcopy() - evaluate = Evaluate( - devset=valset, + try: + demo_candidates = create_n_fewshot_demo_sets( + student=program, + num_candidate_sets=self.num_candidates, + trainset=trainset, + max_labeled_demos=( + LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT + if zeroshot + else self.max_labeled_demos + ), + max_bootstrapped_demos=( + BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT + if zeroshot + else self.max_bootstrapped_demos + ), metric=self.metric, - num_threads=self.num_threads, max_errors=self.max_errors, - display_table=False, - display_progress=True, + teacher_settings=self.teacher_settings, + seed=seed, + metric_threshold=self.metric_threshold, + rng=self.rng, ) + except Exception as e: + print(f"Error generating few-shot examples: {e}") + print("Running without few-shot examples.") + demo_candidates = None + + return demo_candidates - # Determine the number of fewshot examples to use to generate demos for prompt - if max_bootstrapped_demos == 0 and max_labeled_demos == 0: - max_bootstrapped_demos_for_candidate_gen = BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT - max_labeled_demos_for_candidate_gen = LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT + def _propose_instructions( + self, + program: Any, + trainset: List, + demo_candidates: Optional[List], + view_data_batch_size: int, + program_aware_proposer: bool, + data_aware_proposer: bool, + tip_aware_proposer: bool, + fewshot_aware_proposer: bool, + ) -> Dict[int, List[str]]: + print("\n==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==") + print( + "We will use the few-shot examples from the previous step, a generated dataset summary, a summary of the program code, and a randomly selected prompting tip to propose instructions." 
+ ) + + proposer = GroundedProposer( + program=program, + trainset=trainset, + prompt_model=self.prompt_model, + view_data_batch_size=view_data_batch_size, + program_aware=program_aware_proposer, + use_dataset_summary=data_aware_proposer, + use_task_demos=fewshot_aware_proposer, + use_tip=tip_aware_proposer, + set_tip_randomly=tip_aware_proposer, + use_instruct_history=False, + set_history_randomly=False, + verbose=self.verbose, + rng=self.rng + ) + + print("\nProposing instructions...\n") + instruction_candidates = proposer.propose_instructions_for_program( + trainset=trainset, + program=program, + demo_candidates=demo_candidates, + N=self.num_candidates, + T=self.init_temperature, + trial_logs={}, + ) + + for i, pred in enumerate(program.predictors()): + print(f"Proposed Instructions for Predictor {i}:\n") + instruction_candidates[i][0] = get_signature(pred).instructions + for j, instruction in enumerate(instruction_candidates[i]): + print(f"{j}: {instruction}\n") + print("\n") + + return instruction_candidates + + def _optimize_prompt_parameters( + self, + program: Any, + instruction_candidates: Dict[int, List[str]], + demo_candidates: Optional[List], + evaluate: Evaluate, + valset: List, + num_trials: int, + minibatch: bool, + minibatch_size: int, + minibatch_full_eval_steps: int, + seed: int, + ) -> Optional[Any]: + print("Evaluating the default program...\n") + default_score = eval_candidate_program(len(valset), valset, program, evaluate, self.rng) + print(f"Default program score: {default_score}\n") + + # Initialize optimization variables + best_score = default_score + best_program = program.deepcopy() + trial_logs = {} + total_eval_calls = 0 + if minibatch: + scores = [] + else: + scores = [default_score] + full_eval_scores = [default_score] + param_score_dict = defaultdict(list) + fully_evaled_param_combos = {} + + # Define the objective function + def objective(trial): + nonlocal program, best_program, best_score, trial_logs, total_eval_calls, scores, full_eval_scores + + trial_num = trial.number + 1 + if minibatch: + print(f"== Minibatch Trial {trial_num} / {num_trials} ==") else: - max_bootstrapped_demos_for_candidate_gen = max_bootstrapped_demos - max_labeled_demos_for_candidate_gen = max_labeled_demos - - # Generate N few shot example sets (these will inform instruction creation, and be used as few-shot examples in our prompt) - print("Beginning MIPROv2 optimization process...") - print("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==") - if max_bootstrapped_demos > 0: - print("These will be used for as few-shot examples candidates for our program and for creating instructions.\n") + print(f"===== Trial {trial_num} / {num_trials} =====") + + trial_logs[trial_num] = {} + + # Create a new candidate program + candidate_program = program.deepcopy() + + # Choose instructions and demos, insert them into the program + chosen_params = self._select_and_insert_instructions_and_demos( + candidate_program, + instruction_candidates, + demo_candidates, + trial, + trial_logs, + trial_num, + ) + + # Log assembled program + if self.verbose: + print("Evaluating the following candidate program...\n") + print_full_program(candidate_program) + + # Save the candidate program + trial_logs[trial_num]["program_path"] = save_candidate_program( + candidate_program, self.log_dir, trial_num + ) + + # Evaluate the candidate program + batch_size = minibatch_size if minibatch else len(valset) + score = eval_candidate_program( + batch_size, valset, candidate_program, evaluate, self.rng + ) + + # Update best 
score and program + if not minibatch and score > best_score: + best_score = score + best_program = candidate_program.deepcopy() + print(f"{GREEN}Best full score so far!{ENDC} Score: {score}") + + # Log evaluation results + scores.append(score) + if minibatch: + self._log_minibatch_eval( + score, + best_score, + batch_size, + chosen_params, + scores, + full_eval_scores, + trial, + num_trials, + ) else: - print("These will be used for informing instruction proposal.\n") - print(f"Bootstrapping N={self.num_candidates} sets of demonstrations...") - try: - demo_candidates = create_n_fewshot_demo_sets( - student=program, - num_candidate_sets=self.num_candidates, - trainset=trainset, - max_labeled_demos=max_labeled_demos_for_candidate_gen, - max_bootstrapped_demos=max_bootstrapped_demos_for_candidate_gen, - metric=self.metric, - max_errors=self.max_errors, - teacher_settings=self.teacher_settings, - seed=seed, - metric_threshold=self.metric_threshold, + self._log_normal_eval( + score, best_score, chosen_params, scores, trial, num_trials ) - except Exception as e: - print(f"Error generating fewshot examples: {e}") - print("Running without fewshot examples.") - demo_candidates = None - - # Generate N candidate instructions - - # Setup our proposer - print("\n==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==") - print("In this step, by default we will use the few-shot examples from the previous step, a generated dataset summary, a summary of the program code, and a randomly selected prompting tip to propose instructions.") - - proposer = GroundedProposer( - program=program, - trainset=trainset, - prompt_model=self.prompt_model, - # program_code_string=self.program_code_string, - view_data_batch_size=view_data_batch_size, - program_aware=program_aware_proposer, - use_dataset_summary=data_aware_proposer, - use_task_demos=fewshot_aware_proposer, - use_tip=tip_aware_proposer, - set_tip_randomly=tip_aware_proposer, - use_instruct_history=False, - set_history_randomly=False, - verbose = self.verbose, + categorical_key = ",".join(map(str, chosen_params)) + param_score_dict[categorical_key].append( + (score, candidate_program), ) - - print("\nProposing instructions...\n") - instruction_candidates = proposer.propose_instructions_for_program( - trainset=trainset, - program=program, - demo_candidates=demo_candidates, - N=self.num_candidates, - T=self.init_temperature, - trial_logs={}, - ) - for i, pred in enumerate(program.predictors()): - print(f"Proposed Instructions for Predictor {i}:\n") - instruction_candidates[i][0] = get_signature(pred).instructions - for j, instruction in enumerate(instruction_candidates[i]): - print(f"{j}: {instruction}\n") - print("\n") - - # If we're doing zero-shot, reset demo_candidates to none now that we've used them for instruction proposal - if max_bootstrapped_demos == 0 and max_labeled_demos == 0: - demo_candidates = None - - # Initialize variables to track during the optimization process - best_scoring_trial = 0 - trial_logs = {} - total_eval_calls = 0 - param_score_dict = defaultdict(list) # Dictionaries of paramater combinations we've tried, and their associated scores - fully_evaled_param_combos = {} # List of the parameter combinations we've done full evals of - - # Evaluate the default program - print("Evaluating the default program...\n") - default_score = eval_candidate_program(len(valset), valset, program, evaluate) - print(f"Default program score: {default_score}\n") - - best_score = default_score - best_program = program.deepcopy() - - - # Define our trial objective 
- def create_objective( - baseline_program, - instruction_candidates, - demo_candidates, - evaluate, - valset, + trial_logs[trial_num]["num_eval_calls"] = batch_size + trial_logs[trial_num]["full_eval"] = batch_size >= len(valset) + trial_logs[trial_num]["score"] = score + total_eval_calls += batch_size + trial_logs[trial_num]["total_eval_calls_so_far"] = total_eval_calls + trial_logs[trial_num]["program"] = candidate_program.deepcopy() + + # If minibatch, perform full evaluation at intervals + if minibatch and ( + (trial_num % minibatch_full_eval_steps == 0) + or (trial_num == num_trials) ): - def objective(trial): - nonlocal best_program, best_score, best_scoring_trial, trial_logs, total_eval_calls # Allow access to the outer variables - - # Kick off trial - if minibatch: - print(f"== Minibatch Trial {trial.number+1} / {num_trials} ==") - else: - print(f"===== Trial {trial.number+1} / {num_trials} =====") - trial_logs[trial.number+1] = {} - - # Create a new candidate program - candidate_program = baseline_program.deepcopy() - - # Choose set of instructions & demos to use for each predictor - chosen_params = [] - for i, p_new in enumerate(candidate_program.predictors()): - - # Get instruction candidates / demos for our given predictor - p_instruction_candidates = instruction_candidates[i] - if demo_candidates: - p_demo_candidates = demo_candidates[i] - - # Suggest the index of the instruction / demo candidate to use in our trial - instruction_idx = trial.suggest_categorical( - f"{i}_predictor_instruction", - range(len(p_instruction_candidates)), - ) - # chosen_params.append(instruction_idx) - chosen_params.append(f"Predictor {i+1}: Instruction {instruction_idx}") - if demo_candidates: - demos_idx = trial.suggest_categorical( - f"{i}_predictor_demos", range(len(p_demo_candidates)), - ) - chosen_params.append(f"Predictor {i+1}: Few-Shot Set {demos_idx}") - - # Log the selected instruction / demo candidate - trial_logs[trial.number+1][ - f"{i}_predictor_instruction" - ] = instruction_idx - if demo_candidates: - trial_logs[trial.number+1][f"{i}_predictor_demos"] = demos_idx - - dspy.logger.debug(f"instruction_idx {instruction_idx}") - if demo_candidates: - dspy.logger.debug(f"demos_idx {demos_idx}") - - # Set the instruction - selected_instruction = p_instruction_candidates[instruction_idx] - updated_signature = get_signature(p_new).with_instructions( - selected_instruction, - ) - set_signature(p_new, updated_signature) - - # Set the demos - if demo_candidates: - p_new.demos = p_demo_candidates[demos_idx] - - # Log assembled program - if self.verbose: print("Evaluating the following candidate program...\n") - if self.verbose: print_full_program(candidate_program) - - # Save the candidate program - trial_logs[trial.number+1]["program_path"] = save_candidate_program( - candidate_program, self.log_dir, trial.number+1, - ) - - trial_logs[trial.number+1]["num_eval_calls"] = 0 - - # Evaluate the candidate program with relevant batch size - batch_size = minibatch_size if minibatch else len(valset) - - score = eval_candidate_program( - batch_size, valset, candidate_program, - evaluate, - ) - - # Print out a full trace of the program in use - if self.verbose: - print("Full trace of prompts in use on an example...") - get_task_model_history_for_full_example( - candidate_program, self.task_model, valset, evaluate, - ) - - # Log relevant information - categorical_key = ",".join(map(str, chosen_params)) - param_score_dict[categorical_key].append( - (score, candidate_program), - ) - 
-                trial_logs[trial.number+1]["num_eval_calls"] = batch_size
-                trial_logs[trial.number+1]["full_eval"] = batch_size >= len(valset)
-                trial_logs[trial.number+1]["score"] = score
-                trial_logs[trial.number+1]["pruned"] = False
-                total_eval_calls += trial_logs[trial.number+1]["num_eval_calls"]
-                trial_logs[trial.number+1]["total_eval_calls_so_far"] = total_eval_calls
-                trial_logs[trial.number+1]["program"] = candidate_program.deepcopy()
-
-                # If this score was from a full evaluation, update the best program if the new score is better
-                best_score_updated = False
-                if score > best_score and trial_logs[trial.number+1]["full_eval"] and not minibatch:
-                    best_score = score
-                    best_scoring_trial = trial.number+1
-                    best_program = candidate_program.deepcopy()
-                    best_score_updated = True
-
-                if minibatch:
-                    print(f"Score: {score} on minibatch of size {batch_size} with parameters {chosen_params}.\n\n")
-                else:
-                    print(f"Score: {score} with parameters {chosen_params}.")
-                    if best_score_updated:
-                        print(f"{GREEN}New best score updated!{ENDC} Score: {best_score} on trial {best_scoring_trial}.\n\n")
-                    else:
-                        print(f"Best score so far: {best_score} on trial {best_scoring_trial}.\n\n")
-
-
-                # If we're doing minibatching, check to see if it's time to do a full eval
-                if minibatch and (((trial.number+1) % minibatch_full_eval_steps == 0) or (trial.number+1 == num_trials)):
-                    print(f"===== Full Eval {len(fully_evaled_param_combos)+1} =====")
-
-                    # Save old information as the minibatch version
-                    trial_logs[trial.number+1]["mb_score"] = score
-                    trial_logs[trial.number+1]["mb_program_path"] = trial_logs[trial.number+1]["program_path"]
-
-                    # Identify our best program (based on mean of scores so far, and do a full eval on it)
-                    highest_mean_program, mean, combo_key = get_program_with_highest_avg_score(param_score_dict, fully_evaled_param_combos)
-
-                    if trial.number+1 // minibatch_full_eval_steps > 0:
-                        print(f"Doing full eval on next top averaging program (Avg Score: {mean}) so far from mini-batch trials...")
-                    else:
-                        print(f"Doing full eval on top averaging program (Avg Score: {mean}) so far from mini-batch trials...")
-                    full_val_score = eval_candidate_program(
-                        len(valset), valset, highest_mean_program, evaluate,
-                    )
-
-                    # Log relevant information
-                    fully_evaled_param_combos[combo_key] = {"program":highest_mean_program, "score": full_val_score}
-                    total_eval_calls += len(valset)
-                    trial_logs[trial.number+1]["total_eval_calls_so_far"] = total_eval_calls
-                    trial_logs[trial.number+1]["full_eval"] = True
-                    trial_logs[trial.number+1]["program_path"] = save_candidate_program(
-                        program=highest_mean_program, log_dir=self.log_dir, trial_num=trial.number+1, note="full_eval",
-                    )
-                    trial_logs[trial.number+1]["score"] = full_val_score
-
-                    if full_val_score > best_score:
-                        print(f"{GREEN}Best full eval score so far!{ENDC} Score: {full_val_score}")
-                        best_score = full_val_score
-                        best_scoring_trial = trial.number+1
-                        best_program = highest_mean_program.deepcopy()
-                        best_score_updated = True
-                    else:
-                        print(f"Full eval score: {full_val_score}")
-                        print(f"Best full eval score so far: {best_score}")
-                    print("=======================\n\n")
-
-                return score
-
-            return objective
-
-        # Run the trial
-        optuna.logging.set_verbosity(optuna.logging.WARNING)
-        print("==> STEP 3: FINDING OPTIMAL PROMPT PARAMETERS <==")
-        print("In this step, we will evaluate the program over a series of trials with different combinations of instructions and few-shot examples to find the optimal combination. Bayesian Optimization will be used for this search process.\n")
-        objective_function = create_objective(
-            program, instruction_candidates, demo_candidates, evaluate, valset,
+                best_score, best_program = self._perform_full_evaluation(
+                    trial_num,
+                    param_score_dict,
+                    fully_evaled_param_combos,
+                    evaluate,
+                    valset,
+                    trial_logs,
+                    total_eval_calls,
+                    full_eval_scores,
+                    best_score,
+                    best_program,
+                )
+
+            return score
+
+        # Run optimization
+        optuna.logging.set_verbosity(optuna.logging.WARNING)
+        print("==> STEP 3: FINDING OPTIMAL PROMPT PARAMETERS <==")
+        print(
+            "We will evaluate the program over a series of trials with different combinations of instructions and few-shot examples to find the optimal combination using Bayesian Optimization.\n"
+        )
+
+        sampler = optuna.samplers.TPESampler(seed=seed, multivariate=True)
+        study = optuna.create_study(direction="maximize", sampler=sampler)
+        study.optimize(objective, n_trials=num_trials)
+
+        # Attach logs to best program
+        if best_program is not None and self.track_stats:
+            best_program.trial_logs = trial_logs
+            best_program.score = best_score
+            best_program.prompt_model_total_calls = self.prompt_model_total_calls
+            best_program.total_calls = self.total_calls
+
+        print(f"Returning best identified program with score {best_score}!")
+
+        return best_program
+
+    def _log_minibatch_eval(
+        self,
+        score,
+        best_score,
+        batch_size,
+        chosen_params,
+        scores,
+        full_eval_scores,
+        trial,
+        num_trials,
+    ):
+        print(
+            f"Score: {score} on minibatch of size {batch_size} with parameters {chosen_params}."
+        )
+        print(f"Minibatch scores so far: {'['+', '.join([f'{s}' for s in scores])+']'}")
+        trajectory = "[" + ", ".join([f"{s}" for s in full_eval_scores]) + "]"
+        print(f"Full eval scores so far: {trajectory}")
+        print(f"Best full score so far: {best_score}")
+        print(
+            f'{"="*len(f"== Minibatch Trial {trial.number+1} / {num_trials} ==")}\n\n'
+        )
+
+    def _log_normal_eval(
+        self, score, best_score, chosen_params, scores, trial, num_trials
+    ):
+        print(f"Score: {score} with parameters {chosen_params}.")
+        print(f"Scores so far: {'['+', '.join([f'{s}' for s in scores])+']'}")
+        print(f"Best score so far: {best_score}")
+        print(f'{"="*len(f"===== Trial {trial.number+1} / {num_trials} =====")}\n\n')
+
+    def _select_and_insert_instructions_and_demos(
+        self,
+        candidate_program: Any,
+        instruction_candidates: Dict[int, List[str]],
+        demo_candidates: Optional[List],
+        trial: optuna.trial.Trial,
+        trial_logs: Dict,
+        trial_num: int,
+    ) -> List[str]:
+        chosen_params = []
+
+        for i, predictor in enumerate(candidate_program.predictors()):
+            # Select instruction
+            instruction_idx = trial.suggest_categorical(
+                f"{i}_predictor_instruction", range(len(instruction_candidates[i]))
+            )
+            selected_instruction = instruction_candidates[i][instruction_idx]
+            updated_signature = get_signature(predictor).with_instructions(
+                selected_instruction
             )
-        sampler = optuna.samplers.TPESampler(seed=seed, multivariate=True)
-        study = optuna.create_study(direction="maximize", sampler=sampler)
-        _ = study.optimize(objective_function, n_trials=num_trials)
+            set_signature(predictor, updated_signature)
+            trial_logs[trial_num][f"{i}_predictor_instruction"] = instruction_idx
+            chosen_params.append(f"Predictor {i+1}: Instruction {instruction_idx}")
+
+            # Select demos if available
+            if demo_candidates:
+                demos_idx = trial.suggest_categorical(
+                    f"{i}_predictor_demos", range(len(demo_candidates[i]))
+                )
+                predictor.demos = demo_candidates[i][demos_idx]
+                trial_logs[trial_num][f"{i}_predictor_demos"] = demos_idx
+                chosen_params.append(f"Predictor {i+1}: Few-Shot Set {demos_idx}")

-        if best_program is not None and self.track_stats:
-            best_program.trial_logs = trial_logs
-            best_program.score = best_score
-            best_program.prompt_model_total_calls = self.prompt_model_total_calls
-            best_program.total_calls = self.total_calls
+        return chosen_params

-        return best_program
+    def _perform_full_evaluation(
+        self,
+        trial_num: int,
+        param_score_dict: Dict,
+        fully_evaled_param_combos: Dict,
+        evaluate: Evaluate,
+        valset: List,
+        trial_logs: Dict,
+        total_eval_calls: int,
+        full_eval_scores: List[int],
+        best_score: float,
+        best_program: Any,
+    ):
+        print(f"===== Full Eval {len(fully_evaled_param_combos)+1} =====")

-        return student
+        # Identify best program to evaluate fully
+        highest_mean_program, mean_score, combo_key = (
+            get_program_with_highest_avg_score(
+                param_score_dict, fully_evaled_param_combos
+            )
+        )
+        print(
+            f"Doing full eval on next top averaging program (Avg Score: {mean_score}) from minibatch trials..."
+        )
+        full_eval_score = eval_candidate_program(
+            len(valset), valset, highest_mean_program, evaluate, self.rng
+        )
+        full_eval_scores.append(full_eval_score)
+
+        # Log full evaluation results
+        fully_evaled_param_combos[combo_key] = {
+            "program": highest_mean_program,
+            "score": full_eval_score,
+        }
+        total_eval_calls += len(valset)
+        trial_logs[trial_num]["total_eval_calls_so_far"] = total_eval_calls
+        trial_logs[trial_num]["full_eval"] = True
+        trial_logs[trial_num]["program_path"] = save_candidate_program(
+            program=highest_mean_program,
+            log_dir=self.log_dir,
+            trial_num=trial_num,
+            note="full_eval",
+        )
+        trial_logs[trial_num]["score"] = full_eval_score
+
+        # Update best score and program if necessary
+        if full_eval_score > best_score:
+            print(f"{GREEN}New best full eval score!{ENDC} Score: {full_eval_score}")
+            best_score = full_eval_score
+            best_program = highest_mean_program.deepcopy()
+        trajectory = "[" + ", ".join([f"{s}" for s in full_eval_scores]) + "]"
+        print(f"Full eval scores so far: {trajectory}")
+        print(f"Best full score so far: {best_score}")
+        print(len(f"===== Full Eval {len(fully_evaled_param_combos)+1} =====") * "=")
+        print("\n")
+
+        return best_score, best_program
diff --git a/dspy/teleprompt/utils.py b/dspy/teleprompt/utils.py
index ebcabe2299..3640caadf2 100644
--- a/dspy/teleprompt/utils.py
+++ b/dspy/teleprompt/utils.py
@@ -24,14 +24,17 @@
 ### OPTIMIZER TRAINING UTILS ###


-def create_minibatch(trainset, batch_size=50):
+def create_minibatch(trainset, batch_size=50, rng=None):
     """Create a minibatch from the trainset."""

     # Ensure batch_size isn't larger than the size of the dataset
     batch_size = min(batch_size, len(trainset))

-    # Randomly sample indices for the mini-batch
-    sampled_indices = random.sample(range(len(trainset)), batch_size)
+    # If no RNG is provided, fall back to the global random instance
+    rng = rng or random
+
+    # Randomly sample indices for the mini-batch using the provided rng
+    sampled_indices = rng.sample(range(len(trainset)), batch_size)

     # Create the mini-batch using the sampled indices
     minibatch = [trainset[i] for i in sampled_indices]
@@ -39,7 +42,7 @@ def create_minibatch(trainset, batch_size=50):
     return minibatch


-def eval_candidate_program(batch_size, trainset, candidate_program, evaluate):
+def eval_candidate_program(batch_size, trainset, candidate_program, evaluate, rng=None):
     """Evaluate a candidate program on the trainset, using the specified batch size."""
     # Evaluate on the full trainset
     if batch_size >= len(trainset):
@@ -48,7 +51,7 @@ def create_minibatch(trainset, batch_size=50):
     else:
         score = evaluate(
             candidate_program,
-            devset=create_minibatch(trainset, batch_size),
+            devset=create_minibatch(trainset, batch_size, rng),
         )

     return score
@@ -279,6 +282,7 @@ def create_n_fewshot_demo_sets(
     teacher=None,
     include_non_bootstrapped=True,
     seed=0,
+    rng=None
 ):
     """
     This function is copied from random_search.py, and creates fewshot examples in the same way that random search does.
@@ -292,17 +296,15 @@ def create_n_fewshot_demo_sets(
     # Initialize demo_candidates dictionary
     for i, _ in enumerate(student.predictors()):
         demo_candidates[i] = []
-
-    starter_seed = seed
-    # Shuffle the trainset with the starter seed
-    random.Random(starter_seed).shuffle(trainset)
+
+    rng = rng or random.Random(seed)

     # Go through and create each candidate set
     for seed in range(-3, num_candidate_sets):

         print(f"Bootstrapping set {seed+4}/{num_candidate_sets+3}")

-        trainset2 = list(trainset)
+        trainset_copy = list(trainset)

         if seed == -3 and include_non_bootstrapped:
             # zero-shot
@@ -316,7 +318,7 @@ def create_n_fewshot_demo_sets(
             # labels only
             teleprompter = LabeledFewShot(k=max_labeled_demos)
             program2 = teleprompter.compile(
-                student, trainset=trainset2, sample=labeled_sample,
+                student, trainset=trainset_copy, sample=labeled_sample,
             )

         elif seed == -1:
@@ -329,12 +331,12 @@ def create_n_fewshot_demo_sets(
                 teacher_settings=teacher_settings,
                 max_rounds=max_rounds,
             )
-            program2 = program.compile(student, teacher=teacher, trainset=trainset2)
+            program2 = program.compile(student, teacher=teacher, trainset=trainset_copy)

         else:
             # shuffled few-shot
-            random.Random(seed).shuffle(trainset2)
-            size = random.Random(seed).randint(min_num_samples, max_bootstrapped_demos)
+            rng.shuffle(trainset_copy)
+            size = rng.randint(min_num_samples, max_bootstrapped_demos)

             teleprompter = BootstrapFewShot(
                 metric=metric,
@@ -347,7 +349,7 @@
             )

         program2 = teleprompter.compile(
-            student, teacher=teacher, trainset=trainset2,
+            student, teacher=teacher, trainset=trainset_copy,
         )

     for i, _ in enumerate(student.predictors()):
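The common thread in the `utils.py` changes is that every sampling helper now accepts an optional `rng` and falls back to the global `random` module (or a `random.Random(seed)`) when none is supplied, so a caller can make minibatch and few-shot sampling reproducible by threading a single seeded generator through. Below is a small usage sketch, assuming the `dspy` package from this branch is importable; the toy `trainset` is invented for illustration.

```python
import random

from dspy.teleprompt.utils import create_minibatch

# Toy trainset invented for illustration; create_minibatch only samples indices from it.
trainset = [f"example-{i}" for i in range(100)]

# Re-seeding the generator reproduces the exact same minibatch.
batch_a = create_minibatch(trainset, batch_size=10, rng=random.Random(42))
batch_b = create_minibatch(trainset, batch_size=10, rng=random.Random(42))
assert batch_a == batch_b

# Omitting rng falls back to the global `random` module, matching the old behavior.
batch_c = create_minibatch(trainset, batch_size=10)
```

The same pattern applies to `eval_candidate_program` and `create_n_fewshot_demo_sets`, which simply forward the generator to their internal sampling and shuffling calls.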