50 changes: 26 additions & 24 deletions dspy/teleprompt/bootstrap.py
@@ -33,37 +33,36 @@

logger = logging.getLogger(__name__)


class BootstrapFewShot(Teleprompter):
def __init__(
self,
metric=None,
metric_threshold=None,
- teacher_settings: Optional[Dict]=None,
+ teacher_settings: Optional[Dict] = None,
max_bootstrapped_demos=4,
max_labeled_demos=16,
max_rounds=1,
max_errors=5,
):
"""
A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
"""A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
These demos come from a combination of labeled examples in the training set, and bootstrapped demos.

Args:
- metric: Callable
- A function that compares an expected value and predicted value, outputting the result of that comparison.
- metric_threshold: optional float, default `None`
- If the metric yields a numerical value, then check it against this threshold when
- deciding whether or not to accept a bootstrap example.
- teacher_settings: dict, optional
- Settings for the `teacher` model.
- max_bootstrapped_demos: int, default 4
- Maximum number of bootstrapped demonstrations to include
- max_labeled_demos: int, default 16
- Maximum number of labeled demonstrations to include.
- max_rounds: int, default 1
- Number of iterations to attempt generating the required bootstrap examples. If unsuccessful after `max_rounds`, the program ends.
- max_errors: int, default 5
- Maximum number of errors until program ends.
+ metric (Callable): A function that compares an expected value and predicted value,
+ outputting the result of that comparison.
+ metric_threshold (float, optional): If the metric yields a numerical value, then check it
+ against this threshold when deciding whether or not to accept a bootstrap example.
+ Defaults to None.
+ teacher_settings (dict, optional): Settings for the `teacher` model.
+ Defaults to None.
+ max_bootstrapped_demos (int): Maximum number of bootstrapped demonstrations to include.
+ Defaults to 4.
+ max_labeled_demos (int): Maximum number of labeled demonstrations to include.
+ Defaults to 16.
+ max_rounds (int): Number of iterations to attempt generating the required bootstrap
+ examples. If unsuccessful after `max_rounds`, the program ends. Defaults to 1.
+ max_errors (int): Maximum number of errors until program ends. Defaults to 5.
"""
self.metric = metric
self.metric_threshold = metric_threshold
@@ -117,9 +116,10 @@ def _prepare_predictor_mappings(self):
if hasattr(predictor1.signature, "equals"):
assert predictor1.signature.equals(
predictor2.signature,
- ), (f"Student and teacher must have the same signatures. "
+ ), (
+ f"Student and teacher must have the same signatures. "
f"{type(predictor1.signature)} != {type(predictor2.signature)}"
- )
+ )
else:
# fallback in case if .equals is not implemented (e.g. dsp.Prompt)
assert predictor1.signature == predictor2.signature, (
@@ -149,7 +149,8 @@ def _bootstrap(self, *, max_bootstraps=None):
self.name2traces = {name: [] for name in self.name2predictor}

for example_idx, example in enumerate(tqdm.tqdm(self.trainset)):
- if len(bootstrapped) >= max_bootstraps: break
+ if len(bootstrapped) >= max_bootstraps:
+ break

for round_idx in range(self.max_rounds):
bootstrap_attempts += 1
@@ -175,8 +176,8 @@ def _bootstrap_one_example(self, example, round_idx=0):
# score = evaluate(self.metric, display_table=False, display_progress=True)

def _bootstrap_one_example(self, example, round_idx=0):
- name2traces = {} #self.name2traces
- teacher = self.teacher # .deepcopy()
+ name2traces = {}
+ teacher = self.teacher
predictor_cache = {}

try:
@@ -235,10 +236,11 @@ def _bootstrap_one_example(self, example, round_idx=0):

name2traces[predictor_name] = name2traces.get(predictor_name, [])
name2traces[predictor_name].append(demo)

# Update the traces
for name, demos in name2traces.items():
from datasets.fingerprint import Hasher

# If there are multiple traces for the same predictor in the sample example,
# sample 50/50 from the first N-1 traces or the last trace.
if len(demos) > 1:
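For orientation, here is a minimal usage sketch of the class whose docstring is updated above. The program `student`, the metric `exact_match`, and `trainset` are illustrative placeholders rather than part of this diff.

from dspy.teleprompt import BootstrapFewShot

# Hypothetical metric: any callable comparing an expected value and a prediction works.
def exact_match(example, prediction, trace=None):
    return example.answer == prediction.answer

optimizer = BootstrapFewShot(
    metric=exact_match,
    metric_threshold=None,       # only consulted when the metric returns a number
    max_bootstrapped_demos=4,
    max_labeled_demos=16,
    max_rounds=1,
    max_errors=5,
)
compiled_student = optimizer.compile(student, trainset=trainset)  # `student` and `trainset` assumed defined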
88 changes: 57 additions & 31 deletions dspy/teleprompt/bootstrap_finetune.py
@@ -1,5 +1,5 @@
- from collections import defaultdict
import logging
+ from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Union

import dspy
@@ -12,12 +12,10 @@
from dspy.primitives.program import Program
from dspy.teleprompt.teleprompt import Teleprompter


logger = logging.getLogger(__name__)


class FinetuneTeleprompter(Teleprompter):

def __init__(
self,
train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None,
@@ -41,23 +39,25 @@ def __init__(
train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None,
adapter: Optional[Union[Adapter, Dict[LM, Adapter]]] = None,
exclude_demos: bool = False,
- num_threads: int = 6
+ num_threads: int = 6,
):
# TODO(feature): Inputs train_kwargs (a dict with string keys) and
# adapter (Adapter) can depend on the LM they are used with. We are
- # takingthese as parameters for the time being. However, they can be
+ # takingthese as parameters for the time being. However, they can be
# attached to LMs themselves -- an LM could know which adapter it should
# be used with along with the train_kwargs. This will lead the only
# required argument for LM.finetune() to be the train dataset.

super().__init__(train_kwargs=train_kwargs)
self.metric = metric
self.multitask = multitask
self.adapter: Dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
self.exclude_demos = exclude_demos
self.num_threads = num_threads

- def compile(self, student: Program, trainset: List[Example], teacher: Optional[Union[Program, List[Program]]] = None) -> Program:

+ def compile(
+ self, student: Program, trainset: List[Example], teacher: Optional[Union[Program, List[Program]]] = None
+ ) -> Program:
# TODO: Print statements can be converted to logger.info if we ensure
# that the default DSPy logger logs info level messages in notebook
# environments.
@@ -71,24 +71,41 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[U
teachers = [prepare_teacher(student, t) for t in teachers]
for t in teachers:
set_missing_predictor_lms(t)
- trace_data += bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads)
+ trace_data += bootstrap_trace_data(
+ program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads
+ )

logger.info("Preparing the train data...")
key_to_data = {}
for pred_ind, pred in enumerate(student.predictors()):
data_pred_ind = None if self.multitask else pred_ind
training_key = (pred.lm, data_pred_ind)
if training_key not in key_to_data:
- train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind)
+ train_data, data_format = self._prepare_finetune_data(
+ trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
+ )
logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
- finetune_kwargs = dict(lm=pred.lm, train_data=train_data, train_data_format=data_format, train_kwargs=self.train_kwargs[pred.lm])
+ finetune_kwargs = dict(
+ lm=pred.lm,
+ train_data=train_data,
+ train_data_format=data_format,
+ train_kwargs=self.train_kwargs[pred.lm],
+ )
key_to_data[training_key] = finetune_kwargs

logger.info("Starting LM fine-tuning...")
# TODO(feature): We could run batches of fine-tuning jobs in sequence
# to avoid exceeding the number of threads.
err = f"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: {self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will be equal to the number of predictors in the student program. If the `multitask` flag is set to True, the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning jobs will be less than or equal to the number of predictors."
assert len(key_to_data) <= self.num_threads, err
if len(key_to_data) > self.num_threads:
raise ValueError(
"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning "
f"jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: "
f"{self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will "
"be equal to the number of predictors in the student program. If the `multitask` flag is set to True, "
"the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of "
"unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning "
"jobs will be less than or equal to the number of predictors."
)
logger.info(f"{len(key_to_data)} fine-tuning job(s) to start")
key_to_lm = self.finetune_lms(key_to_data)
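To make the job-count arithmetic in the new error message concrete, the sketch below (not part of the PR) mirrors the `training_key = (pred.lm, data_pred_ind)` grouping used in `compile` above; `student` is assumed to be a DSPy program with LMs attached to its predictors.

def count_finetune_jobs(student, multitask: bool) -> int:
    # One fine-tuning job per unique (LM, predictor-index) key; multitask=True collapses
    # the predictor index, leaving one job per unique LM.
    keys = set()
    for pred_ind, pred in enumerate(student.predictors()):
        data_pred_ind = None if multitask else pred_ind
        keys.add((pred.lm, data_pred_ind))
    return len(keys)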

@@ -98,10 +115,10 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[U
training_key = (pred.lm, data_pred_ind)
pred.lm = key_to_lm[training_key]
# TODO: What should the correct behavior be here? Should
- # BootstrapFinetune modify the prompt demos according to the
+ # BootstrapFinetune modify the prompt demos according to the
# train data?
pred.demos = [] if self.exclude_demos else pred.demos

logger.info("BootstrapFinetune has finished compiling the student program")
student._compiled = True
return student
@@ -120,10 +137,13 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:
# up resources for fine-tuning. This might mean introducing a new
# provider method (e.g. prepare_for_finetune) that can be called
# before fine-tuning is started.
logger.info("Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the LM is not running.")
logger.info(
"Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the "
"LM is not running."
)
lm.kill()
key_to_job[key] = lm.finetune(**finetune_kwargs)

key_to_lm = {}
for ind, (key, job) in enumerate(key_to_job.items()):
key_to_lm[key] = job.result()
@@ -143,13 +163,16 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
adapter = self.adapter[lm] or lm.infer_adapter()
data_format = infer_data_format(adapter)
for item in trace_data:
- for pred_ind, _ in enumerate(item['trace']):
+ for pred_ind, _ in enumerate(item["trace"]):
include_data = pred_ind is None or pred_ind == pred_ind
if include_data:
- call_data = build_call_data_from_trace(trace=item['trace'], pred_ind=pred_ind, adapter=adapter, exclude_demos=self.exclude_demos)
+ call_data = build_call_data_from_trace(
+ trace=item["trace"], pred_ind=pred_ind, adapter=adapter, exclude_demos=self.exclude_demos
+ )
data.append(call_data)

import random

random.Random(0).shuffle(data)

return data, data_format
@@ -189,8 +212,11 @@ def bootstrap_trace_data(
# Return a list of dicts with the following keys:
# example_ind, example, prediction, trace, and score (if metric != None)
evaluator = Evaluate(
- devset=dataset, num_threads=num_threads, display_progress=True, return_outputs=True,
- provide_traceback=True # TODO(check with team)
+ devset=dataset,
+ num_threads=num_threads,
+ display_progress=True,
+ return_outputs=True,
+ provide_traceback=True, # TODO(check with team)
)

def wrapped_metric(example, prediction, trace=None):
@@ -286,11 +312,10 @@ def assert_structural_equivalency(program1: object, program2: object):

pzip = zip(program1.named_predictors(), program2.named_predictors())
for ind, ((name1, pred1), (name2, pred2)) in enumerate(pzip):
err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index {ind}: '{name1}' != '{name2}'"
err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index {ind}: '{name1}' != '{name2}'"
assert name1 == name2, err
assert isinstance(pred1, Predict)
assert isinstance(pred2, Predict)
- # assert pred1.signature.equals(pred2.signature)


def assert_no_shared_predictor(program1: Program, program2: Program):
@@ -303,17 +328,18 @@ def assert_no_shared_predictor(program1: Program, program2: Program):
assert not shared_ids, err


- def get_unique_lms(program: Program) -> List[LM]:
- lms = [pred.lm for pred in program.predictors()]
- lms = list(set(lms))
- return lms
+ def get_unique_lms(program: Program) -> List[LM]:
+ lms = [pred.lm for pred in program.predictors()]
+ return list(set(lms))


def launch_lms(program: Program):
lms = get_unique_lms(program)
for lm in lms:
lm.launch()

- def kill_lms(program: Program):
- lms = get_unique_lms(program)
- for lm in lms:

+ def kill_lms(program: Program):
+ lms = get_unique_lms(program)
+ for lm in lms:
lm.kill()
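Finally, a hedged end-to-end sketch of the optimizer defined in this file; `student`, `trainset`, and `metric` are assumed to already exist, and the keyword arguments mirror the attributes assigned in `__init__` above.

from dspy.teleprompt.bootstrap_finetune import BootstrapFinetune

finetuner = BootstrapFinetune(
    metric=metric,        # assumed callable; forwarded to bootstrap_trace_data during compile
    multitask=True,       # one fine-tuning job per unique LM instead of one per predictor
    exclude_demos=False,  # True would clear prompt demos on the compiled predictors
    num_threads=6,        # must be >= the number of fine-tuning jobs (see the check above)
)
finetuned_student = finetuner.compile(student, trainset=trainset)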