Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 43 additions & 26 deletions dspy/teleprompt/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,24 @@


class BootstrapFewShot(Teleprompter):
def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5):
def __init__(
self,
metric=None,
metric_threshold=None,
teacher_settings={},
max_bootstrapped_demos=4,
max_labeled_demos=16,
max_rounds=1,
max_errors=5,
):
self.metric = metric
self.metric_threshold = metric_threshold
self.teacher_settings = teacher_settings

self.max_bootstrapped_demos = max_bootstrapped_demos
self.max_labeled_demos = max_labeled_demos
self.max_rounds = max_rounds
self.max_errors= max_errors
self.max_errors = max_errors
self.error_count = 0
self.error_lock = threading.Lock()

Expand All @@ -59,37 +68,41 @@ def compile(self, student, *, teacher=None, trainset, valset=None):
self.student._suggest_failures = 0

return self.student

def _prepare_student_and_teacher(self, student, teacher):
self.student = student.reset_copy()
self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy()

assert getattr(self.student, '_compiled', False) is False, "Student must be uncompiled."
assert getattr(self.student, "_compiled", False) is False, "Student must be uncompiled."

if self.max_labeled_demos and getattr(self.teacher, '_compiled', False) is False:
if self.max_labeled_demos and getattr(self.teacher, "_compiled", False) is False:
teleprompter = LabeledFewShot(k=self.max_labeled_demos)
self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)

def _prepare_predictor_mappings(self):
name2predictor, predictor2name = {}, {}
student, teacher = self.student, self.teacher

assert len(student.predictors()) == len(teacher.predictors()), "Student and teacher must have the same number of predictors."
assert len(student.predictors()) == len(
teacher.predictors(),
), "Student and teacher must have the same number of predictors."

for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()):
assert name1 == name2, "Student and teacher must have the same program structure."
assert predictor1.signature.equals(predictor2.signature), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
assert predictor1.signature.equals(
predictor2.signature,
), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."

name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
predictor2name[id(predictor1)] = name1

# FIXME(shangyint): This is an ugly hack to bind traces of
# retry.module to retry
# if isinstance(predictor1, Retry):
# predictor2name[id(predictor1.module)] = name1

predictor2name[id(predictor2)] = name2
predictor2name[id(predictor2)] = name2

self.name2predictor = name2predictor
self.predictor2name = predictor2name
Expand All @@ -111,8 +124,8 @@ def _bootstrap(self, *, max_bootstraps=None):
if success:
bootstrapped[example_idx] = True

print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.')
print(f"Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.")

# Unbootstrapped training examples

self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
Expand All @@ -123,10 +136,10 @@ def _bootstrap(self, *, max_bootstraps=None):
# NOTE: Can't yet use evaluate because we need to trace *per example*
# evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
# score = evaluate(self.metric, display_table=False, display_progress=True)

def _bootstrap_one_example(self, example, round_idx=0):
name2traces = self.name2traces
teacher = self.teacher #.deepcopy()
teacher = self.teacher # .deepcopy()
predictor_cache = {}

try:
Expand All @@ -145,7 +158,7 @@ def _bootstrap_one_example(self, example, round_idx=0):

for name, predictor in teacher.named_predictors():
predictor.demos = predictor_cache[name]

if self.metric:
metric_val = self.metric(example, prediction, trace)
if self.metric_threshold:
Expand All @@ -162,13 +175,13 @@ def _bootstrap_one_example(self, example, round_idx=0):
current_error_count = self.error_count
if current_error_count >= self.max_errors:
raise e
print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')
print(f"Failed to run or to evaluate example {example} with {self.metric} due to {e}.")

if success:
for step in trace:
predictor, inputs, outputs = step

if 'dspy_uuid' in example:
if "dspy_uuid" in example:
demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
else:
# TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
Expand All @@ -177,30 +190,34 @@ def _bootstrap_one_example(self, example, round_idx=0):
try:
predictor_name = self.predictor2name[id(predictor)]
except KeyError as e:
continue # FIXME: !
continue # FIXME: !

# TODO: Look closer into this. It's a bit tricky to reproduce.
print(f'Failed to find predictor {predictor} in {self.predictor2name}.')
print('Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.')
print('Try restarting the notebook, or open an issue.')
raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.') from e
print(f"Failed to find predictor {predictor} in {self.predictor2name}.")
print(
"Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.",
)
print("Try restarting the notebook, or open an issue.")
raise KeyError(
f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.",
) from e

name2traces[predictor_name].append(demo)

return success

def _train(self):
rng = random.Random(0)
raw_demos = self.validation

for name, predictor in self.student.named_predictors():
augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]
augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]

sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
sample_size = max(0, sample_size)

raw_demos = rng.sample(raw_demos, sample_size)

if dspy.settings.release >= 20230928:
predictor.demos = raw_demos + augmented_demos
else:
Expand Down
Loading