From 96ee792ed3b5c5a8b5d5e4605863b807f5758eb2 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Thu, 14 Mar 2024 21:12:20 -0700 Subject: [PATCH 1/2] Formatting --- dspy/teleprompt/bootstrap.py | 69 ++-- dspy/teleprompt/copro_optimizer.py | 212 ++++++++----- dspy/teleprompt/ensemble.py | 8 +- dspy/teleprompt/finetune.py | 99 +++--- dspy/teleprompt/knn_fewshot.py | 8 +- dspy/teleprompt/mipro_optimizer.py | 371 ++++++++++++++-------- dspy/teleprompt/random_search.py | 67 ++-- dspy/teleprompt/signature_opt.py | 19 +- dspy/teleprompt/signature_opt_bayesian.py | 60 +++- dspy/teleprompt/teleprompt.py | 3 - dspy/teleprompt/teleprompt_optuna.py | 48 ++- dspy/teleprompt/vanilla.py | 5 +- 12 files changed, 637 insertions(+), 332 deletions(-) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 6fc9dd2ef7..86a1c0f362 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -31,7 +31,16 @@ class BootstrapFewShot(Teleprompter): - def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5): + def __init__( + self, + metric=None, + metric_threshold=None, + teacher_settings={}, + max_bootstrapped_demos=4, + max_labeled_demos=16, + max_rounds=1, + max_errors=5, + ): self.metric = metric self.metric_threshold = metric_threshold self.teacher_settings = teacher_settings @@ -39,7 +48,7 @@ def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_ self.max_bootstrapped_demos = max_bootstrapped_demos self.max_labeled_demos = max_labeled_demos self.max_rounds = max_rounds - self.max_errors= max_errors + self.max_errors = max_errors self.error_count = 0 self.error_lock = threading.Lock() @@ -59,14 +68,14 @@ def compile(self, student, *, teacher=None, trainset, valset=None): self.student._suggest_failures = 0 return self.student - + def _prepare_student_and_teacher(self, student, teacher): self.student = student.reset_copy() self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy() - assert getattr(self.student, '_compiled', False) is False, "Student must be uncompiled." + assert getattr(self.student, "_compiled", False) is False, "Student must be uncompiled." - if self.max_labeled_demos and getattr(self.teacher, '_compiled', False) is False: + if self.max_labeled_demos and getattr(self.teacher, "_compiled", False) is False: teleprompter = LabeledFewShot(k=self.max_labeled_demos) self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset) @@ -74,14 +83,18 @@ def _prepare_predictor_mappings(self): name2predictor, predictor2name = {}, {} student, teacher = self.student, self.teacher - assert len(student.predictors()) == len(teacher.predictors()), "Student and teacher must have the same number of predictors." + assert len(student.predictors()) == len( + teacher.predictors() + ), "Student and teacher must have the same number of predictors." for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()): assert name1 == name2, "Student and teacher must have the same program structure." - assert predictor1.signature.equals(predictor2.signature), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}" + assert predictor1.signature.equals( + predictor2.signature + ), f"Student and teacher must have the same signatures. 
{type(predictor1.signature)} != {type(predictor2.signature)}" assert id(predictor1) != id(predictor2), "Student and teacher must be different objects." - name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2) + name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2) predictor2name[id(predictor1)] = name1 # FIXME(shangyint): This is an ugly hack to bind traces of @@ -89,7 +102,7 @@ def _prepare_predictor_mappings(self): # if isinstance(predictor1, Retry): # predictor2name[id(predictor1.module)] = name1 - predictor2name[id(predictor2)] = name2 + predictor2name[id(predictor2)] = name2 self.name2predictor = name2predictor self.predictor2name = predictor2name @@ -111,8 +124,8 @@ def _bootstrap(self, *, max_bootstraps=None): if success: bootstrapped[example_idx] = True - print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.') - + print(f"Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.") + # Unbootstrapped training examples self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped] @@ -123,10 +136,10 @@ def _bootstrap(self, *, max_bootstraps=None): # NOTE: Can't yet use evaluate because we need to trace *per example* # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12) # score = evaluate(self.metric, display_table=False, display_progress=True) - + def _bootstrap_one_example(self, example, round_idx=0): name2traces = self.name2traces - teacher = self.teacher #.deepcopy() + teacher = self.teacher # .deepcopy() predictor_cache = {} try: @@ -145,7 +158,7 @@ def _bootstrap_one_example(self, example, round_idx=0): for name, predictor in teacher.named_predictors(): predictor.demos = predictor_cache[name] - + if self.metric: metric_val = self.metric(example, prediction, trace) if self.metric_threshold: @@ -162,13 +175,13 @@ def _bootstrap_one_example(self, example, round_idx=0): current_error_count = self.error_count if current_error_count >= self.max_errors: raise e - print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.') - + print(f"Failed to run or to evaluate example {example} with {self.metric} due to {e}.") + if success: for step in trace: predictor, inputs, outputs = step - if 'dspy_uuid' in example: + if "dspy_uuid" in example: demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs) else: # TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case. @@ -177,16 +190,20 @@ def _bootstrap_one_example(self, example, round_idx=0): try: predictor_name = self.predictor2name[id(predictor)] except KeyError as e: - continue # FIXME: ! + continue # FIXME: ! # TODO: Look closer into this. It's a bit tricky to reproduce. - print(f'Failed to find predictor {predictor} in {self.predictor2name}.') - print('Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.') - print('Try restarting the notebook, or open an issue.') - raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.') from e + print(f"Failed to find predictor {predictor} in {self.predictor2name}.") + print( + "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells." + ) + print("Try restarting the notebook, or open an issue.") + raise KeyError( + f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}." 
+ ) from e name2traces[predictor_name].append(demo) - + return success def _train(self): @@ -194,13 +211,13 @@ def _train(self): raw_demos = self.validation for name, predictor in self.student.named_predictors(): - augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos] - + augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos] + sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos)) sample_size = max(0, sample_size) raw_demos = rng.sample(raw_demos, sample_size) - + if dspy.settings.release >= 20230928: predictor.demos = raw_demos + augmented_demos else: diff --git a/dspy/teleprompt/copro_optimizer.py b/dspy/teleprompt/copro_optimizer.py index 9d304cf9ed..3523c73804 100644 --- a/dspy/teleprompt/copro_optimizer.py +++ b/dspy/teleprompt/copro_optimizer.py @@ -31,24 +31,41 @@ * total_calls: The total number of calls to the task metric. These statistics will be returned as attributes of the best program. """ + + class BasicGenerateInstruction(Signature): """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" basic_instruction = dspy.InputField(desc="The initial instructions before optimization") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") + proposed_prefix_for_output_field = dspy.OutputField( + desc="The string at the end of the prompt, which will help the model start solving the task" + ) + class GenerateInstructionGivenAttempts(dspy.Signature): - """You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality. + """You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality. -Your task is to propose a new instruction that will lead a good language model to perform the task even better. Don't be afraid to be creative.""" + Your task is to propose a new instruction that will lead a good language model to perform the task even better. 
Don't be afraid to be creative.""" + + attempted_instructions = dspy.InputField(format=dsp.passages2text) + proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") + proposed_prefix_for_output_field = dspy.OutputField( + desc="The string at the end of the prompt, which will help the model start solving the task" + ) - attempted_instructions = dspy.InputField(format=dsp.passages2text) - proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") class COPRO(Teleprompter): - def __init__(self, prompt_model=None, metric=None, breadth=10, depth=3, init_temperature=1.4, verbose=False, track_stats=False): + def __init__( + self, + prompt_model=None, + metric=None, + breadth=10, + depth=3, + init_temperature=1.4, + verbose=False, + track_stats=False, + ): if breadth <= 1: raise ValueError("Breadth must be greater than 1") self.metric = metric @@ -75,16 +92,16 @@ def _drop_duplicates(self, candidates): last_batch_score = -1 for c in candidates: repeat = False - if c['score'] == last_batch_score: + if c["score"] == last_batch_score: for c2 in last_batch: - if (self._check_candidates_equal(c, c2)): + if self._check_candidates_equal(c, c2): repeat = True break if not repeat: last_batch.append(c) else: last_batch = [c] - last_batch_score = c['score'] + last_batch_score = c["score"] if not repeat: final_candidates.append(c) return final_candidates @@ -95,32 +112,34 @@ def _print_signature(self, predictor): print(f"i: {signature.instructions}") print(f"p: {list(signature.fields.values())[-1].json_schema_extra['prefix']}") print() - + def _get_signature(self, predictor): - if (hasattr(predictor, 'extended_signature')): + if hasattr(predictor, "extended_signature"): return predictor.extended_signature - elif (hasattr(predictor, 'signature')): + elif hasattr(predictor, "signature"): return predictor.signature - + def _set_signature(self, predictor, updated_signature): - if (hasattr(predictor, 'extended_signature')): + if hasattr(predictor, "extended_signature"): predictor.extended_signature = updated_signature - elif (hasattr(predictor, 'signature')): + elif hasattr(predictor, "signature"): predictor.signature = updated_signature - def compile(self, student, *, trainset, eval_kwargs): """student is a program that needs to be optimized, note that it may be zero-shot or already pre-optimized for demos != []""" module = student.deepcopy() evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs) total_calls = 0 - results_best = {id(p):{"depth": [], "max": [], "average": [], "min":[], "std": []} for p in module.predictors()} - results_latest = {id(p):{"depth": [], "max": [], "average": [], "min":[], "std": []} for p in module.predictors()} + results_best = { + id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors() + } + results_latest = { + id(p): {"depth": [], "max": [], "average": [], "min": [], "std": []} for p in module.predictors() + } if self.track_stats: import numpy as np - candidates = {} evaluated_candidates = defaultdict(dict) @@ -130,64 +149,85 @@ def compile(self, student, *, trainset, eval_kwargs): basic_prefix = None *_, last_key = self._get_signature(predictor).fields.keys() basic_instruction = self._get_signature(predictor).instructions - basic_prefix = 
self._get_signature(predictor).fields[last_key].json_schema_extra['prefix'] - if self.prompt_model: + basic_prefix = self._get_signature(predictor).fields[last_key].json_schema_extra["prefix"] + if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): - instruct = dspy.Predict(BasicGenerateInstruction, n=self.breadth-1, temperature=self.init_temperature)(basic_instruction=basic_instruction) + instruct = dspy.Predict( + BasicGenerateInstruction, n=self.breadth - 1, temperature=self.init_temperature + )(basic_instruction=basic_instruction) else: - instruct = dspy.Predict(BasicGenerateInstruction, n=self.breadth-1, temperature=self.init_temperature)(basic_instruction=basic_instruction) + instruct = dspy.Predict( + BasicGenerateInstruction, n=self.breadth - 1, temperature=self.init_temperature + )(basic_instruction=basic_instruction) # Add in our initial prompt as a candidate as well instruct.completions.proposed_instruction.append(basic_instruction) instruct.completions.proposed_prefix_for_output_field.append(basic_prefix) candidates[id(predictor)] = instruct.completions evaluated_candidates[id(predictor)] = {} - - if self.verbose and self.prompt_model: print(f"{self.prompt_model.inspect_history(n=1)}") + + if self.verbose and self.prompt_model: + print(f"{self.prompt_model.inspect_history(n=1)}") latest_candidates = candidates all_candidates = candidates - + module_clone = module.deepcopy() # For each iteration in depth... - for d in range(self.depth): # TODO: fix this so that we eval the new batch of predictors with the new best followoing predictors + for d in range( + self.depth + ): # TODO: fix this so that we eval the new batch of predictors with the new best followoing predictors print(f"Iteration Depth: {d+1}/{self.depth}.") latest_scores = [] - + # Go through our module's predictors for p_i, (p_old, p_new) in enumerate(zip(module.predictors(), module_clone.predictors())): - candidates_ = latest_candidates[id(p_old)] # Use the most recently generated candidates for evaluation + candidates_ = latest_candidates[id(p_old)] # Use the most recently generated candidates for evaluation if len(module.predictors()) > 1: - candidates_ = all_candidates[id(p_old)] # Unless our program has multiple predictors, in which case we need to reevaluate all prompts with the new prompt(s) for the other predictor(s) + candidates_ = all_candidates[ + id(p_old) + ] # Unless our program has multiple predictors, in which case we need to reevaluate all prompts with the new prompt(s) for the other predictor(s) # For each candidate - for c_i, c in enumerate(candidates_): - # Get the candidate instruction and prefix - instruction, prefix = c.proposed_instruction.strip('"').strip(), c.proposed_prefix_for_output_field.strip('"').strip() - - # Set this new module with our instruction / prefix + for c_i, c in enumerate(candidates_): + # Get the candidate instruction and prefix + instruction, prefix = ( + c.proposed_instruction.strip('"').strip(), + c.proposed_prefix_for_output_field.strip('"').strip(), + ) + + # Set this new module with our instruction / prefix *_, last_key = self._get_signature(p_new).fields.keys() - updated_signature = self._get_signature(p_new) \ - .with_instructions(instruction) \ + updated_signature = ( + self._get_signature(p_new) + .with_instructions(instruction) .with_updated_fields(last_key, prefix=prefix) + ) self._set_signature(p_new, updated_signature) - # Score the instruction / prefix - if self.verbose: print("----------------") - for i,predictor in 
enumerate(module_clone.predictors()): - if self.verbose: print(f"Predictor {i+1}") + # Score the instruction / prefix + if self.verbose: + print("----------------") + for i, predictor in enumerate(module_clone.predictors()): + if self.verbose: + print(f"Predictor {i+1}") self._print_signature(predictor) - print(f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for Predictor {p_i+1} of {len(module.predictors())}.") + print( + f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for Predictor {p_i+1} of {len(module.predictors())}." + ) score = evaluate(module_clone, devset=trainset, **eval_kwargs) - if self.verbose and self.prompt_model: print(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}") + if self.verbose and self.prompt_model: + print(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}") total_calls += 1 - if self.verbose: print("----------------") + if self.verbose: + print("----------------") replace_entry = True - if self.verbose: print(f"(instruction, prefix) {(instruction, prefix)}") + if self.verbose: + print(f"(instruction, prefix) {(instruction, prefix)}") # if verbose: print(f"evaluated_candidates[id(p_old)] {evaluated_candidates[id(p_old)]}") - if ((instruction, prefix) in evaluated_candidates[id(p_old)]): + if (instruction, prefix) in evaluated_candidates[id(p_old)]: # if verbose: print(f"if evaluated_candidates[id(p_old)][(instruction, prefix)] {evaluated_candidates[id(p_old)][(instruction, prefix)]}") if evaluated_candidates[id(p_old)][(instruction, prefix)]["score"] >= score: replace_entry = False @@ -201,93 +241,107 @@ def compile(self, student, *, trainset, eval_kwargs): "prefix": prefix, "depth": d, } - - if (len(candidates_)-self.breadth <= c_i): + + if len(candidates_) - self.breadth <= c_i: latest_scores.append(score) if self.track_stats: results_latest[id(p_old)]["depth"].append(d) results_latest[id(p_old)]["max"].append(max(latest_scores)) - results_latest[id(p_old)]["average"].append(sum(latest_scores)/len(latest_scores)) + results_latest[id(p_old)]["average"].append(sum(latest_scores) / len(latest_scores)) results_latest[id(p_old)]["min"].append(min(latest_scores)) results_latest[id(p_old)]["std"].append(np.std(latest_scores)) - + # Now that we've evaluated the candidates, set this predictor to the best performing version # to ensure the next round of scores reflect the best possible version - best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate['score']) + best_candidate = max(evaluated_candidates[id(p_old)].values(), key=lambda candidate: candidate["score"]) *_, last_key = self._get_signature(p_old).fields.keys() - updated_signature = self._get_signature(p_new) \ - .with_instructions(best_candidate["instruction"]) \ + updated_signature = ( + self._get_signature(p_new) + .with_instructions(best_candidate["instruction"]) .with_updated_fields(last_key, prefix=best_candidate["prefix"]) + ) self._set_signature(p_new, updated_signature) - if self.verbose: print(f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\np: {best_candidate['prefix']}") - if self.verbose: print("Full predictor with update: ") - for i,predictor in enumerate(module_clone.predictors()): - if self.verbose: print(f"Predictor {i}") + if self.verbose: + print( + f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\np: {best_candidate['prefix']}" + ) + if self.verbose: + print("Full predictor 
with update: ") + for i, predictor in enumerate(module_clone.predictors()): + if self.verbose: + print(f"Predictor {i}") self._print_signature(predictor) - if d == self.depth-1: + if d == self.depth - 1: break - new_candidates = {} for p_base in module.predictors(): # Build Few-Shot Example of Optimized Prompts attempts = [] shortest_len = self.breadth - shortest_len = min(len(evaluated_candidates[id(p_base)]),shortest_len) + shortest_len = min(len(evaluated_candidates[id(p_base)]), shortest_len) best_predictors = list(evaluated_candidates[id(p_base)].values()) # best_predictors = evaluated_candidates[id(p_base)].values()[:] - best_predictors.sort(key=lambda x: x['score'], reverse=True) + best_predictors.sort(key=lambda x: x["score"], reverse=True) if self.track_stats: - scores = [x['score'] for x in best_predictors][:10] + scores = [x["score"] for x in best_predictors][:10] results_best[id(p_base)]["depth"].append(d) results_best[id(p_base)]["max"].append(max(scores)) - results_best[id(p_base)]["average"].append(sum(scores)/len(scores)) + results_best[id(p_base)]["average"].append(sum(scores) / len(scores)) results_best[id(p_base)]["min"].append(min(scores)) results_best[id(p_base)]["std"].append(np.std(scores)) - - for i in range(shortest_len-1,-1,-1): + + for i in range(shortest_len - 1, -1, -1): # breakpoint() attempts.append(f'Instruction #{shortest_len-i}: {best_predictors[i]["instruction"]}') attempts.append(f'Prefix #{shortest_len-i}: {best_predictors[i]["prefix"]}') attempts.append(f'Resulting Score #{shortest_len-i}: {best_predictors[i]["score"]}') - + # Generate next batch of potential prompts to optimize, with previous attempts as input - if self.prompt_model: + if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): - instr = dspy.Predict(GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature)(attempted_instructions=attempts) + instr = dspy.Predict( + GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature + )(attempted_instructions=attempts) else: - instr = dspy.Predict(GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature)(attempted_instructions=attempts) + instr = dspy.Predict( + GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature + )(attempted_instructions=attempts) - if self.verbose and self.prompt_model: print(f"{self.prompt_model.inspect_history(n=1)}") + if self.verbose and self.prompt_model: + print(f"{self.prompt_model.inspect_history(n=1)}") # Get candidates for each predictor new_candidates[id(p_base)] = instr.completions all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction) - all_candidates[id(p_base)].proposed_prefix_for_output_field.extend(instr.completions.proposed_prefix_for_output_field) + all_candidates[id(p_base)].proposed_prefix_for_output_field.extend( + instr.completions.proposed_prefix_for_output_field + ) - if self.verbose and self.prompt_model: print(f"{self.prompt_model.inspect_history(n=1)}") + if self.verbose and self.prompt_model: + print(f"{self.prompt_model.inspect_history(n=1)}") latest_candidates = new_candidates - + candidates = [] for predictor in module.predictors(): candidates.extend(list(evaluated_candidates[id(predictor)].values())) if self.track_stats: best_predictors = list(evaluated_candidates[id(predictor)].values()) - best_predictors.sort(key=lambda x: x['score'], reverse=True) + best_predictors.sort(key=lambda x: x["score"], reverse=True) - scores = 
[x['score'] for x in best_predictors][:10] + scores = [x["score"] for x in best_predictors][:10] results_best[id(predictor)]["depth"].append(d) results_best[id(predictor)]["max"].append(max(scores)) - results_best[id(predictor)]["average"].append(sum(scores)/len(scores)) + results_best[id(predictor)]["average"].append(sum(scores) / len(scores)) results_best[id(predictor)]["min"].append(min(scores)) results_best[id(predictor)]["std"].append(np.std(scores)) # if verbose: print(f"candidates: {candidates}") - candidates.sort(key=lambda x: x['score'], reverse=True) + candidates.sort(key=lambda x: x["score"], reverse=True) candidates = self._drop_duplicates(candidates) @@ -298,4 +352,4 @@ def compile(self, student, *, trainset, eval_kwargs): best_program.results_best = results_best best_program.results_latest = results_latest - return best_program \ No newline at end of file + return best_program diff --git a/dspy/teleprompt/ensemble.py b/dspy/teleprompt/ensemble.py index 5e0db9bcac..f7cfb01607 100644 --- a/dspy/teleprompt/ensemble.py +++ b/dspy/teleprompt/ensemble.py @@ -6,12 +6,13 @@ TODO: The EnsembledProgram should actually imitate the structure of the individual programs (IF they are all compatible). This allows compiling with an ensemble program as a (singular) teacher. Basically the top majority-compatible trace will end up being used, if dspy.majority is the reduce_fn. """ + class Ensemble(Teleprompter): def __init__(self, *, reduce_fn=None, size=None, deterministic=False): """A common reduce_fn is dspy.majority.""" - + assert deterministic is False, "TODO: Implement example hashing for deterministic ensemble." - + self.reduce_fn = reduce_fn self.size = size self.deterministic = deterministic @@ -21,11 +22,12 @@ def compile(self, programs): reduce_fn = self.reduce_fn import dspy + class EnsembledProgram(dspy.Module): def __init__(self): super().__init__() self.programs = programs - + def forward(self, *args, **kwargs): programs = random.sample(self.programs, size) if size else self.programs outputs = [prog(*args, **kwargs) for prog in programs] diff --git a/dspy/teleprompt/finetune.py b/dspy/teleprompt/finetune.py index 8400fe7749..ca080a5c3d 100644 --- a/dspy/teleprompt/finetune.py +++ b/dspy/teleprompt/finetune.py @@ -18,11 +18,11 @@ # from dspy.evaluate.evaluate import Evaluate -if os.environ.get('DSP_NOTEBOOK_CACHEDIR'): - training_data_directory = os.path.join(os.environ.get('DSP_NOTEBOOK_CACHEDIR'), 'compiler') +if os.environ.get("DSP_NOTEBOOK_CACHEDIR"): + training_data_directory = os.path.join(os.environ.get("DSP_NOTEBOOK_CACHEDIR"), "compiler") print(training_data_directory) else: - training_data_directory = 'local_cache/compiler' + training_data_directory = "local_cache/compiler" if not os.path.exists(training_data_directory): os.makedirs(training_data_directory) @@ -54,20 +54,36 @@ def __init__(self, metric=None, teacher_settings={}, multitask=True): self.multitask = multitask metric = metric or (lambda *args: True) - self.teleprompter = BootstrapFewShot(metric=metric, - max_bootstrapped_demos=999999, - max_labeled_demos=0, # FIXME: TODO: Make this zero? or param, with default as 16 or 0? - teacher_settings=teacher_settings) - - - def compile(self, student, *, teacher=None, trainset, valset=None, - target='t5-large', bsize=12, accumsteps=1, lr=5e-5, epochs=1, bf16=False, int8=False, peft=False, path_prefix=None): - + self.teleprompter = BootstrapFewShot( + metric=metric, + max_bootstrapped_demos=999999, + max_labeled_demos=0, # FIXME: TODO: Make this zero? 
or param, with default as 16 or 0? + teacher_settings=teacher_settings, + ) + + def compile( + self, + student, + *, + teacher=None, + trainset, + valset=None, + target="t5-large", + bsize=12, + accumsteps=1, + lr=5e-5, + epochs=1, + bf16=False, + int8=False, + peft=False, + path_prefix=None, + ): # It's usually better to supply a few-shot teacher, rather than uncompiled module (the student). if teacher is None: - print("WARNING: Using a vanilla teacher. " - "Are you sure you want to use BootstrapFinetune without a compiled teacher?") - + print( + "WARNING: Using a vanilla teacher. " + "Are you sure you want to use BootstrapFinetune without a compiled teacher?" + ) teachers = teacher if isinstance(teacher, list) else [teacher] finetune_data = {} @@ -79,7 +95,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, # Prepare finetune pairs. for name, predictor in compiled.named_predictors(): - name_ = 'all' if multitask else name + name_ = "all" if multitask else name finetune_data[name_] = [] if name_ not in finetune_data else finetune_data[name_] for demo in predictor.demos: @@ -96,44 +112,47 @@ def compile(self, student, *, teacher=None, trainset, valset=None, random.Random(0).shuffle(finetune_data[name_]) print(name_, len(finetune_data[name_])) - # # Dump as files. - # + # finetune_paths = {} for name in finetune_data: data = finetune_data[name] - hashed_name = name + '.' + Hasher.hash(data) - output_path = os.path.join(training_data_directory, f'{hashed_name}.jsonl') + hashed_name = name + "." + Hasher.hash(data) + output_path = os.path.join(training_data_directory, f"{hashed_name}.jsonl") print(output_path) - with open(output_path, 'w') as f: + with open(output_path, "w") as f: for line in data: - f.write(ujson.dumps(line) + '\n') - + f.write(ujson.dumps(line) + "\n") + finetune_paths[name] = output_path - # # Train! # import string + compiler_config = { - 'save': ''.join(random.Random(time.time()).choices(string.ascii_uppercase + string.digits, k=13)), # https://stackoverflow.com/a/2257449/1493011 - 'peft': peft, - 'fp16': False, - 'bf16': bf16, - 'int8': int8, - 'fid': False, - 'rationale': False, - 'batch_size': bsize, - 'epochs': epochs, - 'gradient_accumulation_steps': accumsteps, # 2, - 'lr': lr, + "save": "".join( + random.Random(time.time()).choices(string.ascii_uppercase + string.digits, k=13) + ), # https://stackoverflow.com/a/2257449/1493011 + "peft": peft, + "fp16": False, + "bf16": bf16, + "int8": int8, + "fid": False, + "rationale": False, + "batch_size": bsize, + "epochs": epochs, + "gradient_accumulation_steps": accumsteps, # 2, + "lr": lr, } - compiler_config['save'] = os.path.join(path_prefix, compiler_config['save']) if path_prefix else compiler_config['save'] + compiler_config["save"] = ( + os.path.join(path_prefix, compiler_config["save"]) if path_prefix else compiler_config["save"] + ) from dsp.modules.finetuning import finetune_hf @@ -143,11 +162,11 @@ def compile(self, student, *, teacher=None, trainset, valset=None, for name in finetune_data: training_data_path = finetune_paths[name] compiler_config_ = dict(compiler_config) - compiler_config_['save'] = compiler_config['save'] + '.' + name + compiler_config_["save"] = compiler_config["save"] + "." 
+ name best_ckpt_path = finetune_hf(training_data_path, target, compiler_config_) print(f"#> Best checkpoint path: {best_ckpt_path} for {name}") - finetune_models[name] = dsp.HFModel(model=target, checkpoint=best_ckpt_path) # best_ckpt_path + finetune_models[name] = dsp.HFModel(model=target, checkpoint=best_ckpt_path) # best_ckpt_path # # Set the LMs to the finetuned ones, per module @@ -158,7 +177,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, for (name, predictor), (name2, predictor2) in zip(compiled.named_predictors(), compiled2.named_predictors()): assert name == name2 - name = 'all' if multitask else name + name = "all" if multitask else name # TODO: FIXME: When we assign .lm, the Predict.forward will also set only_query=True. # This is correct for here but we may want to make it more explicitly restricted to finetuned models. @@ -166,5 +185,5 @@ def compile(self, student, *, teacher=None, trainset, valset=None, predictor2.lm = finetune_models[name] assert predictor2.demos == [] - + return compiled2 diff --git a/dspy/teleprompt/knn_fewshot.py b/dspy/teleprompt/knn_fewshot.py index b999447078..fe0d961b8a 100644 --- a/dspy/teleprompt/knn_fewshot.py +++ b/dspy/teleprompt/knn_fewshot.py @@ -17,8 +17,10 @@ def compile(self, student, *, teacher=None, trainset, valset=None): def forward_pass(*args, **kwargs): knn_trainset = self.KNN(**kwargs) few_shot_bootstrap = BootstrapFewShot() - compiled_program = few_shot_bootstrap.compile(student, teacher=teacher, trainset=knn_trainset, valset=valset) + compiled_program = few_shot_bootstrap.compile( + student, teacher=teacher, trainset=knn_trainset, valset=valset + ) return compiled_program(**kwargs) - + student_copy.forward = types.MethodType(forward_pass, student_copy) - return student_copy \ No newline at end of file + return student_copy diff --git a/dspy/teleprompt/mipro_optimizer.py b/dspy/teleprompt/mipro_optimizer.py index c9045ff79f..00f77fbccb 100644 --- a/dspy/teleprompt/mipro_optimizer.py +++ b/dspy/teleprompt/mipro_optimizer.py @@ -3,6 +3,7 @@ import sys import textwrap from collections import defaultdict +from typing import Any import optuna @@ -43,12 +44,16 @@ This information will be returned as attributes of the best program. """ + class BasicGenerateInstruction(Signature): """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" basic_instruction = dspy.InputField(desc="The initial instructions before optimization") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") + proposed_prefix_for_output_field = dspy.OutputField( + desc="The string at the end of the prompt, which will help the model start solving the task" + ) + class BasicGenerateInstructionWithDataObservations(Signature): """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. I will also give you some ``observations`` I have made about the dataset and task. Your task is to propose an instruction that will lead a good language model to perform the task well. 
Don't be afraid to be creative.""" @@ -56,54 +61,88 @@ class BasicGenerateInstructionWithDataObservations(Signature): basic_instruction = dspy.InputField(desc="The initial instructions before optimization") observations = dspy.InputField(desc="Observations about the dataset and task") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") + proposed_prefix_for_output_field = dspy.OutputField( + desc="The string at the end of the prompt, which will help the model start solving the task" + ) + class BasicGenerateInstructionWithExamples(dspy.Signature): - ("""You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Specifically, I will also provide you with the current ``basic instruction`` that is being used for this task. I will also provide you with some ``examples`` of the expected inputs and outputs. + """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Specifically, I will also provide you with the current ``basic instruction`` that is being used for this task. I will also provide you with some ``examples`` of the expected inputs and outputs. + + Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" + + # attempted_instructions = dspy.InputField(format=str, desc="Previously attempted task instructions, along with their resulting validation score, and an example of the instruction in use on a sample from our dataset.") + basic_instruction = dspy.InputField(desc="The initial instructions before optimization") + # examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") + examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") + proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") + proposed_prefix_for_output_field = dspy.OutputField( + desc="The string at the end of the prompt, which will help the model start solving the task" + ) -Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""") - # attempted_instructions = dspy.InputField(format=str, desc="Previously attempted task instructions, along with their resulting validation score, and an example of the instruction in use on a sample from our dataset.") - basic_instruction = dspy.InputField(desc="The initial instructions before optimization") - # examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") - examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") - proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") class BasicGenerateInstructionWithExamplesAndDataObservations(dspy.Signature): - ("""You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Specifically, I will give you some ``observations`` I have made about the dataset and task, along with some ``examples`` of the expected inputs and outputs. 
I will also provide you with the current ``basic instruction`` that is being used for this task. + """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Specifically, I will give you some ``observations`` I have made about the dataset and task, along with some ``examples`` of the expected inputs and outputs. I will also provide you with the current ``basic instruction`` that is being used for this task. + + Your task is to propose a new improved instruction and prefix for the output field that will lead a good language model to perform the task well. Don't be afraid to be creative.""" + + observations = dspy.InputField(desc="Observations about the dataset and task") + examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") + basic_instruction = dspy.InputField(desc="The initial instructions before optimization") + proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") + proposed_prefix_for_output_field = dspy.OutputField( + desc="The string at the end of the prompt, which will help the model start solving the task" + ) -Your task is to propose a new improved instruction and prefix for the output field that will lead a good language model to perform the task well. Don't be afraid to be creative.""") - observations = dspy.InputField(desc="Observations about the dataset and task") - examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") - basic_instruction = dspy.InputField(desc="The initial instructions before optimization") - proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") class ObservationSummarizer(dspy.Signature): - ("""Given a series of observations I have made about my dataset, please summarize them into a brief 2-3 sentence summary which highlights only the most important details.""") + """Given a series of observations I have made about my dataset, please summarize them into a brief 2-3 sentence summary which highlights only the most important details.""" + observations = dspy.InputField(desc="Observations I have made about my dataset") - summary = dspy.OutputField(desc="Two to Three sentence summary of only the most significant highlights of my observations") + summary = dspy.OutputField( + desc="Two to Three sentence summary of only the most significant highlights of my observations" + ) + class DatasetDescriptor(dspy.Signature): - ("""Given several examples from a dataset please write observations about trends that hold for most or all of the samples. """ - """Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. """ - """It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative""") - + ( + """Given several examples from a dataset please write observations about trends that hold for most or all of the samples. """ + """Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. """ + """It will be useful to make an educated guess as to the nature of the task this dataset will enable. 
Don't be afraid to be creative""" + ) + examples = dspy.InputField(desc="Sample data points from the dataset") observations = dspy.OutputField(desc="Somethings that holds true for most or all of the data you observed") + class DatasetDescriptorWithPriorObservations(dspy.Signature): - ("""Given several examples from a dataset please write observations about trends that hold for most or all of the samples. """ - """I will also provide you with a few observations I have already made. Please add your own observations or if you feel the observations are comprehensive say 'COMPLETE' """ - """Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. """ - """It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative""") - + ( + """Given several examples from a dataset please write observations about trends that hold for most or all of the samples. """ + """I will also provide you with a few observations I have already made. Please add your own observations or if you feel the observations are comprehensive say 'COMPLETE' """ + """Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. """ + """It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative""" + ) + examples = dspy.InputField(desc="Sample data points from the dataset") prior_observations = dspy.InputField(desc="Some prior observations I made about the data") - observations = dspy.OutputField(desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add") + observations = dspy.OutputField( + desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add" + ) + class MIPRO(Teleprompter): - def __init__(self, metric, prompt_model=None, task_model=None, teacher_settings={}, num_candidates=10, init_temperature=1.0, verbose=False, track_stats=True, view_data_batch_size=10): + def __init__( + self, + metric, + prompt_model=None, + task_model=None, + teacher_settings={}, + num_candidates=10, + init_temperature=1.0, + verbose=False, + track_stats=True, + view_data_batch_size=10, + ): self.num_candidates = num_candidates self.metric = metric self.init_temperature = init_temperature @@ -113,17 +152,22 @@ def __init__(self, metric, prompt_model=None, task_model=None, teacher_settings= self.track_stats = track_stats self.teacher_settings = teacher_settings self.view_data_batch_size = view_data_batch_size - + def _print_full_program(self, program): - for i,predictor in enumerate(program.predictors()): - if self.verbose: print(f"Predictor {i}") - if self.verbose: print(f"i: {self._get_signature(predictor).instructions}") + for i, predictor in enumerate(program.predictors()): + if self.verbose: + print(f"Predictor {i}") + if self.verbose: + print(f"i: {self._get_signature(predictor).instructions}") *_, last_field = self._get_signature(predictor).fields.values() - if self.verbose: print(f"p: {last_field.json_schema_extra['prefix']}") - if self.verbose: print("\n") - + if self.verbose: + print(f"p: {last_field.json_schema_extra['prefix']}") + if self.verbose: + print("\n") + def _print_model_history(self, model, n=1): - if self.verbose: print(f"Model ({model}) History:") + if self.verbose: + print(f"Model ({model}) History:") model.inspect_history(n=n) def _observe_data(self, trainset, max_iterations=10): @@ -134,8 +178,10 @@ def 
_observe_data(self, trainset, max_iterations=10): skips = 0 iterations = 0 for b in range(self.view_data_batch_size, len(trainset), self.view_data_batch_size): - upper_lim = min(len(trainset), b+self.view_data_batch_size) - output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(prior_observations=observations, examples=(trainset[b:upper_lim].__repr__())) + upper_lim = min(len(trainset), b + self.view_data_batch_size) + output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)( + prior_observations=observations, examples=(trainset[b:upper_lim].__repr__()) + ) iterations += 1 if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE": skips += 1 @@ -149,9 +195,8 @@ def _observe_data(self, trainset, max_iterations=10): summary = dspy.Predict(ObservationSummarizer, n=1, temperature=1.0)(observations=observations) return summary.summary - - def _create_example_string(self, fields, example): + def _create_example_string(self, fields, example): # Building the output string output = [] for field in fields: @@ -167,21 +212,30 @@ def _create_example_string(self, fields, example): output.append(field_str) # Joining all the field strings - return '\n'.join(output) + return "\n".join(output) def _get_signature(self, predictor): - if (hasattr(predictor, 'extended_signature')): + if hasattr(predictor, "extended_signature"): return predictor.extended_signature - elif (hasattr(predictor, 'signature')): + elif hasattr(predictor, "signature"): return predictor.signature - + return None + def _set_signature(self, predictor, updated_signature): - if (hasattr(predictor, 'extended_signature')): + if hasattr(predictor, "extended_signature"): predictor.extended_signature = updated_signature - elif (hasattr(predictor, 'signature')): + elif hasattr(predictor, "signature"): predictor.signature = updated_signature - - def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo_candidates, devset): + + def _generate_first_N_candidates( # noqa: N802 + self, + module: dspy.Module, + N: int, # noqa: N803 + view_data: bool, + view_examples: bool, + demo_candidates: dict, + devset, + ) -> tuple[dict, dict]: candidates = {} evaluated_candidates = defaultdict(dict) @@ -189,26 +243,25 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo # Create data observations self.observations = None with dspy.settings.context(lm=self.prompt_model): - self.observations = self._observe_data(devset).replace("Observations:","").replace("Summary:","") - + self.observations = self._observe_data(devset).replace("Observations:", "").replace("Summary:", "") + if view_examples: example_sets = {} for predictor in module.predictors(): # Get all augmented examples example_set = {} - all_sets_of_examples = demo_candidates[id(predictor)] # Get all generated sets of examples + all_sets_of_examples = demo_candidates[id(predictor)] # Get all generated sets of examples for example_set_i, set_of_examples in enumerate(all_sets_of_examples): - if example_set_i != 0: # Skip the no examples case - for example in set_of_examples: # Get each individual example in the set - if "augmented" in example.keys(): - if example["augmented"]: - if example_set_i not in example_set: - example_set[example_set_i] = [] - fields_to_use = signature_to_template(predictor.signature).fields - input_variable_names = list(self._get_signature(predictor).input_fields.keys()) - example_string = self._create_example_string(fields_to_use, example) - 
example_set[example_set_i].append(example_string) - example_sets[id(predictor)] = example_set + if example_set_i != 0: # Skip the no examples case + for example in set_of_examples: # Get each individual example in the set + if "augmented" in example and example["augmented"]: + if example_set_i not in example_set: + example_set[example_set_i] = [] + fields_to_use = signature_to_template(predictor.signature).fields + _input_variable_names = list(self._get_signature(predictor).input_fields.keys()) + example_string = self._create_example_string(fields_to_use, example) + example_set[example_set_i].append(example_string) + example_sets[id(predictor)] = example_set else: example_set[example_set_i] = [] example_sets[id(predictor)] = example_set @@ -223,7 +276,7 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo with dspy.settings.context(lm=self.prompt_model): # Data & Examples if view_data and view_examples: - if 1 not in example_sets[id(predictor)].keys(): + if 1 not in example_sets[id(predictor)]: raise ValueError("No examples found for the given predictor") instruct = None for i in range(1, self.num_candidates): @@ -239,15 +292,21 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo if not instruct: instruct = new_instruct else: - instruct.completions.proposed_instruction.extend(new_instruct.completions.proposed_instruction) - instruct.completions.proposed_prefix_for_output_field.extend(new_instruct.completions.proposed_prefix_for_output_field) + instruct.completions.proposed_instruction.extend( + new_instruct.completions.proposed_instruction + ) + instruct.completions.proposed_prefix_for_output_field.extend( + new_instruct.completions.proposed_prefix_for_output_field + ) # Just data - elif view_data: - instruct = dspy.Predict(BasicGenerateInstructionWithDataObservations, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction, observations=self.observations) + elif view_data: + instruct = dspy.Predict( + BasicGenerateInstructionWithDataObservations, n=N - 1, temperature=self.init_temperature + )(basic_instruction=basic_instruction, observations=self.observations) # Just examples - elif view_examples: + elif view_examples: instruct = None - for i in range(1,self.num_candidates): # Note: skip over the first example set which is empty + for i in range(1, self.num_candidates): # Note: skip over the first example set which is empty new_instruct = dspy.Predict( BasicGenerateInstructionWithExamples, n=1, @@ -259,33 +318,55 @@ def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo if not instruct: instruct = new_instruct else: - instruct.completions.proposed_instruction.extend(new_instruct.completions.proposed_instruction) - instruct.completions.proposed_prefix_for_output_field.extend(new_instruct.completions.proposed_prefix_for_output_field) + instruct.completions.proposed_instruction.extend( + new_instruct.completions.proposed_instruction, + ) + instruct.completions.proposed_prefix_for_output_field.extend( + new_instruct.completions.proposed_prefix_for_output_field, + ) # Neither - else: - instruct = dspy.Predict(BasicGenerateInstruction, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction) - + else: + instruct = dspy.Predict(BasicGenerateInstruction, n=N - 1, temperature=self.init_temperature)( + basic_instruction=basic_instruction + ) + # Add in our initial prompt as a candidate as well instruct.completions.proposed_instruction.insert(0, basic_instruction) 
instruct.completions.proposed_prefix_for_output_field.insert(0, basic_prefix) candidates[id(predictor)] = instruct.completions evaluated_candidates[id(predictor)] = {} - - if self.verbose: self._print_model_history(self.prompt_model) - + + if self.verbose: + self._print_model_history(self.prompt_model) + return candidates, evaluated_candidates - def compile(self, student, *, trainset, num_trials, max_bootstrapped_demos, max_labeled_demos, eval_kwargs, seed=42, view_data=True, view_examples=True, requires_permission_to_run=True): + def compile( + self, + student: dspy.Program, + *, + trainset: list[dspy.Example], + num_trials: int, + max_bootstrapped_demos: int, + max_labeled_demos: int, + eval_kwargs: dict[str, Any], + seed=42, + view_data=True, + view_examples=True, + requires_permission_to_run=True, + ) -> dspy.Program: # Define ANSI escape codes for colors - YELLOW = '\033[93m' - BLUE = '\033[94m' - BOLD = '\033[1m' - ENDC = '\033[0m' # Resets the color to default + YELLOW = "\033[93m" + BLUE = "\033[94m" + BOLD = "\033[1m" + ENDC = "\033[0m" # Resets the color to default random.seed(seed) - + estimated_task_model_calls_wo_module_calls = len(trainset) * num_trials # M * T * P - estimated_prompt_model_calls = 10 + self.num_candidates * len(student.predictors()) # num data summary calls + N * P + estimated_prompt_model_calls = 10 + self.num_candidates * len( + student.predictors() + ) # num data summary calls + N * P user_message = textwrap.dedent(f"""\ {YELLOW}{BOLD}WARNING: Projected Language Model (LM) Calls{ENDC} @@ -305,7 +386,7 @@ def compile(self, student, *, trainset, num_trials, max_bootstrapped_demos, max_ {YELLOW}- Reducing the number of trials (`num_trials`), the size of the trainset, or the number of LM calls in your program.{ENDC} {YELLOW}- Using a cheaper task model to optimize the prompt.{ENDC}""") - + user_confirmation_message = textwrap.dedent(f"""\ To proceed with the execution of this program, please confirm by typing {BLUE}'y'{ENDC} for yes or {BLUE}'n'{ENDC} for no. @@ -315,47 +396,54 @@ def compile(self, student, *, trainset, num_trials, max_bootstrapped_demos, max_ """) print(user_message) - - sys.stdout.flush() # Flush the output buffer to force the message to print + sys.stdout.flush() # Flush the output buffer to force the message to print - run=True + run = True if requires_permission_to_run: print(user_confirmation_message) user_input = input("Do you wish to continue? 
(y/n): ").strip().lower() - if user_input != 'y': + if user_input != "y": print("Compilation aborted by the user.") - run=False + run = False if run: # Set up program and evaluation function module = student.deepcopy() evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs) - + # In the case where the bootstrapped and labeled demos are set to 0, we'll stil bootstrap examples to use in our meta prompt - if max_bootstrapped_demos==0 and max_labeled_demos==0: #TODO: address case when max_bootstrapped alone is 0 - max_bootstrapped_demos_for_candidate_gen = 1 - max_labeled_demos_for_candidate_gen = 1 #TODO: this might only need to be 0 + if ( + max_bootstrapped_demos == 0 and max_labeled_demos == 0 + ): # TODO: address case when max_bootstrapped alone is 0 + max_bootstrapped_demos_for_candidate_gen = 1 + max_labeled_demos_for_candidate_gen = 1 # TODO: this might only need to be 0 else: - max_bootstrapped_demos_for_candidate_gen = max_bootstrapped_demos + max_bootstrapped_demos_for_candidate_gen = max_bootstrapped_demos max_labeled_demos_for_candidate_gen = max_labeled_demos # Generate N few shot example sets demo_candidates = {} for i in range(self.num_candidates): - if i == 0: # Story empty set of demos as default for index 0 + if i == 0: # Story empty set of demos as default for index 0 for module_p in module.predictors(): if id(module_p) not in demo_candidates: demo_candidates[id(module_p)] = [] demo_candidates[id(module_p)].append([]) else: - if self.verbose: print(f"Creating basic bootstrap: {i}/{self.num_candidates-1}") + if self.verbose: + print(f"Creating basic bootstrap: {i}/{self.num_candidates-1}") # Create a new basic bootstrap few - shot program . rng = random.Random(i) shuffled_trainset = trainset[:] # Create a copy of devset rng.shuffle(shuffled_trainset) # Shuffle the copy - tp = BootstrapFewShot(metric = self.metric, max_bootstrapped_demos=max_bootstrapped_demos_for_candidate_gen, max_labeled_demos=max_labeled_demos_for_candidate_gen, teacher_settings=self.teacher_settings) + tp = BootstrapFewShot( + metric=self.metric, + max_bootstrapped_demos=max_bootstrapped_demos_for_candidate_gen, + max_labeled_demos=max_labeled_demos_for_candidate_gen, + teacher_settings=self.teacher_settings, + ) candidate_program = tp.compile(student=module.deepcopy(), trainset=shuffled_trainset) # Store the candidate demos @@ -363,16 +451,18 @@ def compile(self, student, *, trainset, num_trials, max_bootstrapped_demos, max_ if id(module_p) not in demo_candidates: demo_candidates[id(module_p)] = [] demo_candidates[id(module_p)].append(candidate_p.demos) - + # Generate N candidate prompts - instruction_candidates, _ = self._generate_first_N_candidates(module, self.num_candidates, view_data, view_examples, demo_candidates, trainset) + instruction_candidates, _ = self._generate_first_N_candidates( + module, self.num_candidates, view_data, view_examples, demo_candidates, trainset + ) # Reset demo_candidates to None for our optimization if the user asked for no fewshot examples - if max_bootstrapped_demos==0 and max_labeled_demos==0: + if max_bootstrapped_demos == 0 and max_labeled_demos == 0: demo_candidates = None # Initialize variables to store the best program and its score - best_score = float('-inf') + best_score = float("-inf") best_program = None trial_num = 0 @@ -384,40 +474,54 @@ def objective(trial): nonlocal best_program, best_score, trial_num, trial_logs # Allow access to the outer variables candidate_program = baseline_program.deepcopy() - # Suggest the instruction to use for 
our predictor + # Suggest the instruction to use for our predictor print(f"Starting trial #{trial_num}") trial_logs[trial_num] = {} for p_old, p_new in zip(baseline_program.predictors(), candidate_program.predictors()): - # Get instruction candidates for our given predictor p_instruction_candidates = instruction_candidates[id(p_old)] - if demo_candidates: p_demo_candidates = demo_candidates[id(p_old)] + if demo_candidates: + p_demo_candidates = demo_candidates[id(p_old)] # Suggest the index of the instruction candidate to use in our trial - instruction_idx = trial.suggest_categorical(f"{id(p_old)}_predictor_instruction",range(len(p_instruction_candidates))) - if demo_candidates: demos_idx = trial.suggest_categorical(f"{id(p_old)}_predictor_demos",range(len(p_demo_candidates))) + instruction_idx = trial.suggest_categorical( + f"{id(p_old)}_predictor_instruction", range(len(p_instruction_candidates)) + ) + if demo_candidates: + demos_idx = trial.suggest_categorical( + f"{id(p_old)}_predictor_demos", range(len(p_demo_candidates)) + ) trial_logs[trial_num][f"{id(p_old)}_predictor_instruction"] = instruction_idx - if demo_candidates: trial_logs[trial_num][f"{id(p_old)}_predictor_demos"] = demos_idx + if demo_candidates: + trial_logs[trial_num][f"{id(p_old)}_predictor_demos"] = demos_idx - # Get the selected instruction candidate + # Get the selected instruction candidate selected_candidate = p_instruction_candidates[instruction_idx] selected_instruction = selected_candidate.proposed_instruction.strip('"').strip() selected_prefix = selected_candidate.proposed_prefix_for_output_field.strip('"').strip() # Use this candidates in our program *_, last_field = self._get_signature(p_new).fields.keys() - updated_signature = self._get_signature(p_new).with_instructions(selected_instruction).with_updated_fields(last_field, prefix=selected_prefix) + updated_signature = ( + self._get_signature(p_new) + .with_instructions(selected_instruction) + .with_updated_fields(last_field, prefix=selected_prefix) + ) self._set_signature(p_new, updated_signature) # Get the selected demos - if demo_candidates: selected_demos = p_demo_candidates[demos_idx] + if demo_candidates: + selected_demos = p_demo_candidates[demos_idx] # Use these demos in our program - if demo_candidates: p_new.demos = selected_demos - - if self.verbose: print("Evaling the following program:") - if self.verbose: self._print_full_program(candidate_program) + if demo_candidates: + p_new.demos = selected_demos + + if self.verbose: + print("Evaling the following program:") + if self.verbose: + self._print_full_program(candidate_program) trial_logs[trial_num]["program"] = candidate_program # Evaluate with the new prompts @@ -430,11 +534,13 @@ def objective(trial): end_index = min((i + 1) * batch_size, len(trainset)) split_trainset = trainset[start_index:end_index] split_score = evaluate(candidate_program, devset=split_trainset, display_table=0) - if self.verbose: print(f"{i}st split score: {split_score}") + if self.verbose: + print(f"{i}st split score: {split_score}") total_score += split_score * len(split_trainset) - curr_weighted_avg_score = total_score / min((i+1)*100,len(trainset)) - if self.verbose: print(f"curr average score: {curr_weighted_avg_score}") + curr_weighted_avg_score = total_score / min((i + 1) * 100, len(trainset)) + if self.verbose: + print(f"curr average score: {curr_weighted_avg_score}") trial.report(curr_weighted_avg_score, i) @@ -443,35 +549,38 @@ def objective(trial): print("Trial pruned.") trial_logs[trial_num]["score"] = 
curr_weighted_avg_score trial_logs[trial_num]["pruned"] = True - trial_num += 1 + trial_num += 1 raise optuna.TrialPruned() - - if self.verbose: print(f"Fully evaled score: {curr_weighted_avg_score}") - if self.verbose: self._print_model_history(self.task_model, n=1) + + if self.verbose: + print(f"Fully evaled score: {curr_weighted_avg_score}") + if self.verbose: + self._print_model_history(self.task_model, n=1) score = curr_weighted_avg_score - + trial_logs[trial_num]["score"] = curr_weighted_avg_score trial_logs[trial_num]["pruned"] = False - + # Update the best program if the current score is better if score > best_score: best_score = score best_program = candidate_program.deepcopy() - - trial_num += 1 + + trial_num += 1 return score return objective - # Run the trial + # Run the trial objective_function = create_objective(module, instruction_candidates, demo_candidates, evaluate, trainset) sampler = optuna.samplers.TPESampler(seed=seed) study = optuna.create_study(direction="maximize", sampler=sampler) - score = study.optimize(objective_function, n_trials=num_trials) + _score = study.optimize(objective_function, n_trials=num_trials) if best_program is not None and self.track_stats: best_program.trial_logs = trial_logs print(f"Returning {best_program} from continue_program") - return best_program \ No newline at end of file + return best_program + return None diff --git a/dspy/teleprompt/random_search.py b/dspy/teleprompt/random_search.py index b95bd6e461..7958bc4717 100644 --- a/dspy/teleprompt/random_search.py +++ b/dspy/teleprompt/random_search.py @@ -24,7 +24,19 @@ class BootstrapFewShotWithRandomSearch(Teleprompter): - def __init__(self, metric, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, num_candidate_programs=16, num_threads=6, max_errors=10, stop_at_score=None, metric_threshold=None): + def __init__( + self, + metric, + teacher_settings={}, + max_bootstrapped_demos=4, + max_labeled_demos=16, + max_rounds=1, + num_candidate_programs=16, + num_threads=6, + max_errors=10, + stop_at_score=None, + metric_threshold=None, + ): self.metric = metric self.teacher_settings = teacher_settings self.max_rounds = max_rounds @@ -64,17 +76,22 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None if seed == -3: # zero-shot program2 = student.reset_copy() - + elif seed == -2: # labels only teleprompter = LabeledFewShot(k=self.max_labeled_demos) program2 = teleprompter.compile(student, trainset=trainset2, sample=labeled_sample) - + elif seed == -1: # unshuffled few-shot - program = BootstrapFewShot(metric=self.metric, metric_threshold=self.metric_threshold, max_bootstrapped_demos=self.max_num_samples, - max_labeled_demos=self.max_labeled_demos, - teacher_settings=self.teacher_settings, max_rounds=self.max_rounds) + program = BootstrapFewShot( + metric=self.metric, + metric_threshold=self.metric_threshold, + max_bootstrapped_demos=self.max_num_samples, + max_labeled_demos=self.max_labeled_demos, + teacher_settings=self.teacher_settings, + max_rounds=self.max_rounds, + ) program2 = program.compile(student, teacher=teacher, trainset=trainset2) else: @@ -83,37 +100,47 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None random.Random(seed).shuffle(trainset2) size = random.Random(seed).randint(self.min_num_samples, self.max_num_samples) - teleprompter = BootstrapFewShot(metric=self.metric, metric_threshold=self.metric_threshold, max_bootstrapped_demos=size, - max_labeled_demos=self.max_labeled_demos, - 
teacher_settings=self.teacher_settings, - max_rounds=self.max_rounds) + teleprompter = BootstrapFewShot( + metric=self.metric, + metric_threshold=self.metric_threshold, + max_bootstrapped_demos=size, + max_labeled_demos=self.max_labeled_demos, + teacher_settings=self.teacher_settings, + max_rounds=self.max_rounds, + ) program2 = teleprompter.compile(student, teacher=teacher, trainset=trainset2) - evaluate = Evaluate(devset=self.valset, metric=self.metric, num_threads=self.num_threads, - max_errors=self.max_errors, display_table=False, display_progress=True) + evaluate = Evaluate( + devset=self.valset, + metric=self.metric, + num_threads=self.num_threads, + max_errors=self.max_errors, + display_table=False, + display_progress=True, + ) score, subscores = evaluate(program2, return_all_scores=True) all_subscores.append(subscores) ############ Assertion-aware Optimization ############ - if hasattr(program2, '_suggest_failures'): + if hasattr(program2, "_suggest_failures"): score = score - program2._suggest_failures * 0.2 - if hasattr(program2, '_assert_failures'): + if hasattr(program2, "_assert_failures"): score = 0 if program2._assert_failures > 0 else score ###################################################### - print('Score:', score, 'for set:', [len(predictor.demos) for predictor in program2.predictors()]) + print("Score:", score, "for set:", [len(predictor.demos) for predictor in program2.predictors()]) if len(scores) == 0 or score > max(scores): - print('New best score:', score, 'for seed', seed) + print("New best score:", score, "for seed", seed) best_program = program2 scores.append(score) print(f"Scores so far: {scores}") - print('Best score:', max(scores)) + print("Best score:", max(scores)) score_data.append((score, subscores, seed, program2)) @@ -125,8 +152,8 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None transposed_subscores = zip(*[subscores for _, subscores, *_ in top_3_scores if subscores]) avg_of_max_per_entry = sum(max(entry) for entry in transposed_subscores) / len(top_3_scores[0][1]) - print(f'Average of max per entry across top {k} scores: {avg_of_max_per_entry}') - + print(f"Average of max per entry across top {k} scores: {avg_of_max_per_entry}") + if self.stop_at_score is not None and score >= self.stop_at_score: print(f"Stopping early because score {score} is >= stop_at_score {self.stop_at_score}") break @@ -140,8 +167,6 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None return best_program - - # sample between 4 and 10 examples from traces # TODO: FIXME: The max number of demos should be determined in part by the LM's tokenizer + max_length. # This does require executing the program, or at least the predictor. diff --git a/dspy/teleprompt/signature_opt.py b/dspy/teleprompt/signature_opt.py index 9ba1da92c9..2b41f8de41 100644 --- a/dspy/teleprompt/signature_opt.py +++ b/dspy/teleprompt/signature_opt.py @@ -1,4 +1,3 @@ - from .copro_optimizer import COPRO """ @@ -32,10 +31,22 @@ These statistics will be returned as attributes of the best program. """ + class SignatureOptimizer(COPRO): - def __init__(self, prompt_model=None, metric=None, breadth=10, depth=3, init_temperature=1.4, verbose=False, track_stats=False): - print("\u001b[31m[WARNING] SignatureOptimizer has been deprecated and replaced with COPRO. SignatureOptimizer will be removed in a future release. 
\u001b[31m") + def __init__( + self, + prompt_model=None, + metric=None, + breadth=10, + depth=3, + init_temperature=1.4, + verbose=False, + track_stats=False, + ): + print( + "\u001b[31m[WARNING] SignatureOptimizer has been deprecated and replaced with COPRO. SignatureOptimizer will be removed in a future release. \u001b[31m" + ) super().__init__(prompt_model, metric, breadth, depth, init_temperature, verbose, track_stats) def compile(self, student, *, devset, eval_kwargs): - return super().compile(student, trainset=devset, eval_kwargs=eval_kwargs) \ No newline at end of file + return super().compile(student, trainset=devset, eval_kwargs=eval_kwargs) diff --git a/dspy/teleprompt/signature_opt_bayesian.py b/dspy/teleprompt/signature_opt_bayesian.py index b74c961511..f338644dc2 100644 --- a/dspy/teleprompt/signature_opt_bayesian.py +++ b/dspy/teleprompt/signature_opt_bayesian.py @@ -1,4 +1,3 @@ - from dspy.teleprompt.mipro_optimizer import MIPRO """ @@ -35,11 +34,60 @@ This information will be returned as attributes of the best program. """ + class BayesianSignatureOptimizer(MIPRO): - def __init__(self, prompt_model=None, task_model=None, teacher_settings={}, n=10, metric=None, init_temperature=1.0, verbose=False, track_stats=True, view_data_batch_size=10): - print("\u001b[31m[WARNING] BayesianSignatureOptimizer has been deprecated and replaced with MIPRO. BayesianSignatureOptimizer will be removed in a future release. \u001b[31m") + def __init__( + self, + prompt_model=None, + task_model=None, + teacher_settings={}, + n=10, + metric=None, + init_temperature=1.0, + verbose=False, + track_stats=True, + view_data_batch_size=10, + ): + print( + "\u001b[31m[WARNING] BayesianSignatureOptimizer has been deprecated and replaced with MIPRO. BayesianSignatureOptimizer will be removed in a future release. 
\u001b[31m" + ) - super().__init__(metric=metric,prompt_model=prompt_model, task_model=task_model, teacher_settings=teacher_settings,num_candidates=n,init_temperature=init_temperature,verbose=verbose,track_stats=track_stats,view_data_batch_size=view_data_batch_size) + super().__init__( + metric=metric, + prompt_model=prompt_model, + task_model=task_model, + teacher_settings=teacher_settings, + num_candidates=n, + init_temperature=init_temperature, + verbose=verbose, + track_stats=track_stats, + view_data_batch_size=view_data_batch_size, + ) - def compile(self, student, *, devset, max_bootstrapped_demos, max_labeled_demos, eval_kwargs, seed=42, optuna_trials_num, view_data=True, view_examples=True, requires_permission_to_run=False, num_trials=None): - return super().compile(student, trainset=devset, max_bootstrapped_demos=max_bootstrapped_demos, max_labeled_demos=max_labeled_demos, eval_kwargs=eval_kwargs, seed=seed, view_data=view_data, view_examples=view_examples, requires_permission_to_run=requires_permission_to_run, num_trials=optuna_trials_num) + def compile( + self, + student, + *, + devset, + max_bootstrapped_demos, + max_labeled_demos, + eval_kwargs, + seed=42, + optuna_trials_num, + view_data=True, + view_examples=True, + requires_permission_to_run=False, + num_trials=None, + ): + return super().compile( + student, + trainset=devset, + max_bootstrapped_demos=max_bootstrapped_demos, + max_labeled_demos=max_labeled_demos, + eval_kwargs=eval_kwargs, + seed=seed, + view_data=view_data, + view_examples=view_examples, + requires_permission_to_run=requires_permission_to_run, + num_trials=optuna_trials_num, + ) diff --git a/dspy/teleprompt/teleprompt.py b/dspy/teleprompt/teleprompt.py index ab3ae060c4..fe5b0b9500 100644 --- a/dspy/teleprompt/teleprompt.py +++ b/dspy/teleprompt/teleprompt.py @@ -1,6 +1,3 @@ - - - class Teleprompter: def __init__(self): pass diff --git a/dspy/teleprompt/teleprompt_optuna.py b/dspy/teleprompt/teleprompt_optuna.py index 501bd71fde..9c5290c91e 100644 --- a/dspy/teleprompt/teleprompt_optuna.py +++ b/dspy/teleprompt/teleprompt_optuna.py @@ -7,7 +7,16 @@ class BootstrapFewShotWithOptuna(Teleprompter): - def __init__(self, metric, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, num_candidate_programs=16, num_threads=6): + def __init__( + self, + metric, + teacher_settings={}, + max_bootstrapped_demos=4, + max_labeled_demos=16, + max_rounds=1, + num_candidate_programs=16, + num_threads=6, + ): self.metric = metric self.teacher_settings = teacher_settings self.max_rounds = max_rounds @@ -29,31 +38,42 @@ def __init__(self, metric, teacher_settings={}, max_bootstrapped_demos=4, max_la def objective(self, trial): program2 = self.student.reset_copy() - for (name, compiled_predictor), (_, program2_predictor) in zip(self.compiled_teleprompter.named_predictors(), program2.named_predictors()): + for (name, compiled_predictor), (_, program2_predictor) in zip( + self.compiled_teleprompter.named_predictors(), program2.named_predictors() + ): all_demos = compiled_predictor.demos demo_index = trial.suggest_int(f"demo_index_for_{name}", 0, len(all_demos) - 1) selected_demo = dict(all_demos[demo_index]) program2_predictor.demos = [selected_demo] - evaluate = Evaluate(devset=self.valset, metric=self.metric, num_threads=self.num_threads, - display_table=False, display_progress=True) + evaluate = Evaluate( + devset=self.valset, + metric=self.metric, + num_threads=self.num_threads, + display_table=False, + display_progress=True, + ) score, _ = 
evaluate(program2, return_all_scores=True) trial.set_user_attr("program", program2) return score - def compile(self, student, *, teacher=None, max_demos, trainset, valset=None): self.trainset = trainset self.valset = valset or trainset self.student = student.reset_copy() self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy() - teleprompter_optimize = BootstrapFewShot(metric=self.metric, max_bootstrapped_demos=max_demos, - max_labeled_demos=self.max_labeled_demos, - teacher_settings=self.teacher_settings, - max_rounds=self.max_rounds) - self.compiled_teleprompter = teleprompter_optimize.compile(self.student, teacher=self.teacher, trainset=self.trainset) - study = optuna.create_study(direction='maximize') + teleprompter_optimize = BootstrapFewShot( + metric=self.metric, + max_bootstrapped_demos=max_demos, + max_labeled_demos=self.max_labeled_demos, + teacher_settings=self.teacher_settings, + max_rounds=self.max_rounds, + ) + self.compiled_teleprompter = teleprompter_optimize.compile( + self.student, teacher=self.teacher, trainset=self.trainset + ) + study = optuna.create_study(direction="maximize") study.optimize(self.objective, n_trials=self.num_candidate_sets) best_program = study.trials[study.best_trial.number].user_attrs["program"] - print('Best score:', study.best_value) - print('Best program:', best_program) - return best_program \ No newline at end of file + print("Best score:", study.best_value) + print("Best program:", best_program) + return best_program diff --git a/dspy/teleprompt/vanilla.py b/dspy/teleprompt/vanilla.py index ced8312116..4f437f9c15 100644 --- a/dspy/teleprompt/vanilla.py +++ b/dspy/teleprompt/vanilla.py @@ -20,10 +20,11 @@ def compile(self, student, *, trainset, sample=True): if sample: predictor.demos = rng.sample(self.trainset, min(self.k, len(self.trainset))) else: - predictor.demos = self.trainset[:min(self.k, len(self.trainset))] + predictor.demos = self.trainset[: min(self.k, len(self.trainset))] return self.student - + + # NOTE: I believe templatev2 keeps rdemos as long as they have the last field. # This may change later, especially with the introduction of required vs optional fields. # NOTE: Since we're relying on downstream code to handle the demos, this sampling may be sub-sampled. From 4244c786896ace6c89664968511bb756696d5eb1 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Thu, 14 Mar 2024 22:21:29 -0700 Subject: [PATCH 2/2] formatting --- dspy/teleprompt/bootstrap.py | 8 ++--- dspy/teleprompt/copro_optimizer.py | 28 ++++++++++------ dspy/teleprompt/finetune.py | 4 +-- dspy/teleprompt/knn_fewshot.py | 5 ++- dspy/teleprompt/mipro_optimizer.py | 40 ++++++++++++++--------- dspy/teleprompt/signature_opt.py | 2 +- dspy/teleprompt/signature_opt_bayesian.py | 2 +- dspy/teleprompt/teleprompt_optuna.py | 4 +-- 8 files changed, 57 insertions(+), 36 deletions(-) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 86a1c0f362..de10a74a37 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -84,13 +84,13 @@ def _prepare_predictor_mappings(self): student, teacher = self.student, self.teacher assert len(student.predictors()) == len( - teacher.predictors() + teacher.predictors(), ), "Student and teacher must have the same number of predictors." for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()): assert name1 == name2, "Student and teacher must have the same program structure." 
assert predictor1.signature.equals( - predictor2.signature + predictor2.signature, ), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}" assert id(predictor1) != id(predictor2), "Student and teacher must be different objects." @@ -195,11 +195,11 @@ def _bootstrap_one_example(self, example, round_idx=0): # TODO: Look closer into this. It's a bit tricky to reproduce. print(f"Failed to find predictor {predictor} in {self.predictor2name}.") print( - "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells." + "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.", ) print("Try restarting the notebook, or open an issue.") raise KeyError( - f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}." + f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.", ) from e name2traces[predictor_name].append(demo) diff --git a/dspy/teleprompt/copro_optimizer.py b/dspy/teleprompt/copro_optimizer.py index 3523c73804..a1a2062c09 100644 --- a/dspy/teleprompt/copro_optimizer.py +++ b/dspy/teleprompt/copro_optimizer.py @@ -39,7 +39,7 @@ class BasicGenerateInstruction(Signature): basic_instruction = dspy.InputField(desc="The initial instructions before optimization") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( - desc="The string at the end of the prompt, which will help the model start solving the task" + desc="The string at the end of the prompt, which will help the model start solving the task", ) @@ -51,7 +51,7 @@ class GenerateInstructionGivenAttempts(dspy.Signature): attempted_instructions = dspy.InputField(format=dsp.passages2text) proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( - desc="The string at the end of the prompt, which will help the model start solving the task" + desc="The string at the end of the prompt, which will help the model start solving the task", ) @@ -153,11 +153,15 @@ def compile(self, student, *, trainset, eval_kwargs): if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): instruct = dspy.Predict( - BasicGenerateInstruction, n=self.breadth - 1, temperature=self.init_temperature + BasicGenerateInstruction, + n=self.breadth - 1, + temperature=self.init_temperature, )(basic_instruction=basic_instruction) else: instruct = dspy.Predict( - BasicGenerateInstruction, n=self.breadth - 1, temperature=self.init_temperature + BasicGenerateInstruction, + n=self.breadth - 1, + temperature=self.init_temperature, )(basic_instruction=basic_instruction) # Add in our initial prompt as a candidate as well instruct.completions.proposed_instruction.append(basic_instruction) @@ -175,7 +179,7 @@ def compile(self, student, *, trainset, eval_kwargs): # For each iteration in depth... for d in range( - self.depth + self.depth, ): # TODO: fix this so that we eval the new batch of predictors with the new best followoing predictors print(f"Iteration Depth: {d+1}/{self.depth}.") @@ -214,7 +218,7 @@ def compile(self, student, *, trainset, eval_kwargs): print(f"Predictor {i+1}") self._print_signature(predictor) print( - f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for Predictor {p_i+1} of {len(module.predictors())}." 
+ f"At Depth {d+1}/{self.depth}, Evaluating Prompt Candidate #{c_i+1}/{len(candidates_)} for Predictor {p_i+1} of {len(module.predictors())}.", ) score = evaluate(module_clone, devset=trainset, **eval_kwargs) if self.verbose and self.prompt_model: @@ -264,7 +268,7 @@ def compile(self, student, *, trainset, eval_kwargs): self._set_signature(p_new, updated_signature) if self.verbose: print( - f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\np: {best_candidate['prefix']}" + f"Updating Predictor {id(p_old)} to:\ni: {best_candidate['instruction']}\np: {best_candidate['prefix']}", ) if self.verbose: print("Full predictor with update: ") @@ -305,11 +309,15 @@ def compile(self, student, *, trainset, eval_kwargs): if self.prompt_model: with dspy.settings.context(lm=self.prompt_model): instr = dspy.Predict( - GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature + GenerateInstructionGivenAttempts, + n=self.breadth, + temperature=self.init_temperature, )(attempted_instructions=attempts) else: instr = dspy.Predict( - GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature + GenerateInstructionGivenAttempts, + n=self.breadth, + temperature=self.init_temperature, )(attempted_instructions=attempts) if self.verbose and self.prompt_model: @@ -318,7 +326,7 @@ def compile(self, student, *, trainset, eval_kwargs): new_candidates[id(p_base)] = instr.completions all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction) all_candidates[id(p_base)].proposed_prefix_for_output_field.extend( - instr.completions.proposed_prefix_for_output_field + instr.completions.proposed_prefix_for_output_field, ) if self.verbose and self.prompt_model: diff --git a/dspy/teleprompt/finetune.py b/dspy/teleprompt/finetune.py index ca080a5c3d..ef79c1ca4a 100644 --- a/dspy/teleprompt/finetune.py +++ b/dspy/teleprompt/finetune.py @@ -82,7 +82,7 @@ def compile( if teacher is None: print( "WARNING: Using a vanilla teacher. " - "Are you sure you want to use BootstrapFinetune without a compiled teacher?" 
+ "Are you sure you want to use BootstrapFinetune without a compiled teacher?", ) teachers = teacher if isinstance(teacher, list) else [teacher] @@ -136,7 +136,7 @@ def compile( compiler_config = { "save": "".join( - random.Random(time.time()).choices(string.ascii_uppercase + string.digits, k=13) + random.Random(time.time()).choices(string.ascii_uppercase + string.digits, k=13), ), # https://stackoverflow.com/a/2257449/1493011 "peft": peft, "fp16": False, diff --git a/dspy/teleprompt/knn_fewshot.py b/dspy/teleprompt/knn_fewshot.py index fe0d961b8a..ca9405268a 100644 --- a/dspy/teleprompt/knn_fewshot.py +++ b/dspy/teleprompt/knn_fewshot.py @@ -18,7 +18,10 @@ def forward_pass(*args, **kwargs): knn_trainset = self.KNN(**kwargs) few_shot_bootstrap = BootstrapFewShot() compiled_program = few_shot_bootstrap.compile( - student, teacher=teacher, trainset=knn_trainset, valset=valset + student, + teacher=teacher, + trainset=knn_trainset, + valset=valset, ) return compiled_program(**kwargs) diff --git a/dspy/teleprompt/mipro_optimizer.py b/dspy/teleprompt/mipro_optimizer.py index 00f77fbccb..c9e34189fd 100644 --- a/dspy/teleprompt/mipro_optimizer.py +++ b/dspy/teleprompt/mipro_optimizer.py @@ -51,7 +51,7 @@ class BasicGenerateInstruction(Signature): basic_instruction = dspy.InputField(desc="The initial instructions before optimization") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( - desc="The string at the end of the prompt, which will help the model start solving the task" + desc="The string at the end of the prompt, which will help the model start solving the task", ) @@ -62,7 +62,7 @@ class BasicGenerateInstructionWithDataObservations(Signature): observations = dspy.InputField(desc="Observations about the dataset and task") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( - desc="The string at the end of the prompt, which will help the model start solving the task" + desc="The string at the end of the prompt, which will help the model start solving the task", ) @@ -77,7 +77,7 @@ class BasicGenerateInstructionWithExamples(dspy.Signature): examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( - desc="The string at the end of the prompt, which will help the model start solving the task" + desc="The string at the end of the prompt, which will help the model start solving the task", ) @@ -91,7 +91,7 @@ class BasicGenerateInstructionWithExamplesAndDataObservations(dspy.Signature): basic_instruction = dspy.InputField(desc="The initial instructions before optimization") proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") proposed_prefix_for_output_field = dspy.OutputField( - desc="The string at the end of the prompt, which will help the model start solving the task" + desc="The string at the end of the prompt, which will help the model start solving the task", ) @@ -100,7 +100,7 @@ class ObservationSummarizer(dspy.Signature): observations = dspy.InputField(desc="Observations I have made about my dataset") summary = dspy.OutputField( - desc="Two to Three sentence summary of only the most significant highlights of my observations" + desc="Two to Three sentence summary of only the most 
significant highlights of my observations", ) @@ -126,7 +126,7 @@ class DatasetDescriptorWithPriorObservations(dspy.Signature): examples = dspy.InputField(desc="Sample data points from the dataset") prior_observations = dspy.InputField(desc="Some prior observations I made about the data") observations = dspy.OutputField( - desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add" + desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add", ) @@ -180,7 +180,8 @@ def _observe_data(self, trainset, max_iterations=10): for b in range(self.view_data_batch_size, len(trainset), self.view_data_batch_size): upper_lim = min(len(trainset), b + self.view_data_batch_size) output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)( - prior_observations=observations, examples=(trainset[b:upper_lim].__repr__()) + prior_observations=observations, + examples=(trainset[b:upper_lim].__repr__()), ) iterations += 1 if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE": @@ -293,15 +294,17 @@ def _generate_first_N_candidates( # noqa: N802 instruct = new_instruct else: instruct.completions.proposed_instruction.extend( - new_instruct.completions.proposed_instruction + new_instruct.completions.proposed_instruction, ) instruct.completions.proposed_prefix_for_output_field.extend( - new_instruct.completions.proposed_prefix_for_output_field + new_instruct.completions.proposed_prefix_for_output_field, ) # Just data elif view_data: instruct = dspy.Predict( - BasicGenerateInstructionWithDataObservations, n=N - 1, temperature=self.init_temperature + BasicGenerateInstructionWithDataObservations, + n=N - 1, + temperature=self.init_temperature, )(basic_instruction=basic_instruction, observations=self.observations) # Just examples elif view_examples: @@ -327,7 +330,7 @@ def _generate_first_N_candidates( # noqa: N802 # Neither else: instruct = dspy.Predict(BasicGenerateInstruction, n=N - 1, temperature=self.init_temperature)( - basic_instruction=basic_instruction + basic_instruction=basic_instruction, ) # Add in our initial prompt as a candidate as well @@ -365,7 +368,7 @@ def compile( estimated_task_model_calls_wo_module_calls = len(trainset) * num_trials # M * T * P estimated_prompt_model_calls = 10 + self.num_candidates * len( - student.predictors() + student.predictors(), ) # num data summary calls + N * P user_message = textwrap.dedent(f"""\ @@ -454,7 +457,12 @@ def compile( # Generate N candidate prompts instruction_candidates, _ = self._generate_first_N_candidates( - module, self.num_candidates, view_data, view_examples, demo_candidates, trainset + module, + self.num_candidates, + view_data, + view_examples, + demo_candidates, + trainset, ) # Reset demo_candidates to None for our optimization if the user asked for no fewshot examples @@ -486,11 +494,13 @@ def objective(trial): # Suggest the index of the instruction candidate to use in our trial instruction_idx = trial.suggest_categorical( - f"{id(p_old)}_predictor_instruction", range(len(p_instruction_candidates)) + f"{id(p_old)}_predictor_instruction", + range(len(p_instruction_candidates)), ) if demo_candidates: demos_idx = trial.suggest_categorical( - f"{id(p_old)}_predictor_demos", range(len(p_demo_candidates)) + f"{id(p_old)}_predictor_demos", + range(len(p_demo_candidates)), ) trial_logs[trial_num][f"{id(p_old)}_predictor_instruction"] = instruction_idx if demo_candidates: diff --git 
a/dspy/teleprompt/signature_opt.py b/dspy/teleprompt/signature_opt.py index 2b41f8de41..e5c364daae 100644 --- a/dspy/teleprompt/signature_opt.py +++ b/dspy/teleprompt/signature_opt.py @@ -44,7 +44,7 @@ def __init__( track_stats=False, ): print( - "\u001b[31m[WARNING] SignatureOptimizer has been deprecated and replaced with COPRO. SignatureOptimizer will be removed in a future release. \u001b[31m" + "\u001b[31m[WARNING] SignatureOptimizer has been deprecated and replaced with COPRO. SignatureOptimizer will be removed in a future release. \u001b[31m", ) super().__init__(prompt_model, metric, breadth, depth, init_temperature, verbose, track_stats) diff --git a/dspy/teleprompt/signature_opt_bayesian.py b/dspy/teleprompt/signature_opt_bayesian.py index f338644dc2..51ca108227 100644 --- a/dspy/teleprompt/signature_opt_bayesian.py +++ b/dspy/teleprompt/signature_opt_bayesian.py @@ -49,7 +49,7 @@ def __init__( view_data_batch_size=10, ): print( - "\u001b[31m[WARNING] BayesianSignatureOptimizer has been deprecated and replaced with MIPRO. BayesianSignatureOptimizer will be removed in a future release. \u001b[31m" + "\u001b[31m[WARNING] BayesianSignatureOptimizer has been deprecated and replaced with MIPRO. BayesianSignatureOptimizer will be removed in a future release. \u001b[31m", ) super().__init__( diff --git a/dspy/teleprompt/teleprompt_optuna.py b/dspy/teleprompt/teleprompt_optuna.py index 9c5290c91e..ad2f6fe4ad 100644 --- a/dspy/teleprompt/teleprompt_optuna.py +++ b/dspy/teleprompt/teleprompt_optuna.py @@ -39,7 +39,7 @@ def __init__( def objective(self, trial): program2 = self.student.reset_copy() for (name, compiled_predictor), (_, program2_predictor) in zip( - self.compiled_teleprompter.named_predictors(), program2.named_predictors() + self.compiled_teleprompter.named_predictors(), program2.named_predictors(), ): all_demos = compiled_predictor.demos demo_index = trial.suggest_int(f"demo_index_for_{name}", 0, len(all_demos) - 1) @@ -69,7 +69,7 @@ def compile(self, student, *, teacher=None, max_demos, trainset, valset=None): max_rounds=self.max_rounds, ) self.compiled_teleprompter = teleprompter_optimize.compile( - self.student, teacher=self.teacher, trainset=self.trainset + self.student, teacher=self.teacher, trainset=self.trainset, ) study = optuna.create_study(direction="maximize") study.optimize(self.objective, n_trials=self.num_candidate_sets)
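
Usage sketch for the Optuna-based teleprompter reformatted above. It relies only on the constructor and compile() signatures visible in these hunks (metric, max_bootstrapped_demos, max_labeled_demos, num_candidate_programs, num_threads; compile(student, *, teacher=None, max_demos, trainset, valset=None)). The program class SimpleQA, the metric exact_match, the toy trainset, and the LM choice are illustrative placeholders rather than anything defined in this diff.

import dspy
from dspy.teleprompt.teleprompt_optuna import BootstrapFewShotWithOptuna

# An LM must be configured before compiling; the model name here is arbitrary.
dspy.settings.configure(lm=dspy.OpenAI(model="gpt-3.5-turbo"))


class SimpleQA(dspy.Module):
    # Hypothetical single-predictor program; any dspy.Module with predictors works.
    def __init__(self):
        super().__init__()
        self.generate = dspy.Predict("question -> answer")

    def forward(self, question):
        return self.generate(question=question)


def exact_match(example, prediction, trace=None):
    # Hypothetical metric: 1.0 when the predicted answer matches the gold answer.
    return float(prediction.answer.strip().lower() == example.answer.strip().lower())


trainset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    dspy.Example(question="What is the capital of France?", answer="Paris").with_inputs("question"),
]

teleprompter = BootstrapFewShotWithOptuna(
    metric=exact_match,
    max_bootstrapped_demos=4,
    max_labeled_demos=16,
    num_candidate_programs=16,  # presumably mapped to self.num_candidate_sets, i.e. the n_trials passed to study.optimize() above
    num_threads=6,
)

# compile() first bootstraps demos with BootstrapFewShot, then each Optuna trial
# selects one demo per predictor via trial.suggest_int and the best-scoring
# candidate program (stored as a trial user attr) is returned.
best_program = teleprompter.compile(
    SimpleQA(),
    max_demos=4,        # keyword-only, per the compile() signature in the hunk above
    trainset=trainset,  # valset defaults to the trainset when omitted
)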