Skip to content

Commit 649ba32

Browse files
authored
Merge pull request #654 from thomasahle/main
Formatting for new optimizers
2 parents cc8b193 + aee6baf commit 649ba32

File tree

12 files changed

+659
-333
lines changed

12 files changed

+659
-333
lines changed

dspy/teleprompt/bootstrap.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,24 @@
3131

3232

3333
class BootstrapFewShot(Teleprompter):
34-
def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5):
34+
def __init__(
35+
self,
36+
metric=None,
37+
metric_threshold=None,
38+
teacher_settings={},
39+
max_bootstrapped_demos=4,
40+
max_labeled_demos=16,
41+
max_rounds=1,
42+
max_errors=5,
43+
):
3544
self.metric = metric
3645
self.metric_threshold = metric_threshold
3746
self.teacher_settings = teacher_settings
3847

3948
self.max_bootstrapped_demos = max_bootstrapped_demos
4049
self.max_labeled_demos = max_labeled_demos
4150
self.max_rounds = max_rounds
42-
self.max_errors= max_errors
51+
self.max_errors = max_errors
4352
self.error_count = 0
4453
self.error_lock = threading.Lock()
4554

@@ -59,37 +68,41 @@ def compile(self, student, *, teacher=None, trainset, valset=None):
5968
self.student._suggest_failures = 0
6069

6170
return self.student
62-
71+
6372
def _prepare_student_and_teacher(self, student, teacher):
6473
self.student = student.reset_copy()
6574
self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy()
6675

67-
assert getattr(self.student, '_compiled', False) is False, "Student must be uncompiled."
76+
assert getattr(self.student, "_compiled", False) is False, "Student must be uncompiled."
6877

69-
if self.max_labeled_demos and getattr(self.teacher, '_compiled', False) is False:
78+
if self.max_labeled_demos and getattr(self.teacher, "_compiled", False) is False:
7079
teleprompter = LabeledFewShot(k=self.max_labeled_demos)
7180
self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)
7281

7382
def _prepare_predictor_mappings(self):
7483
name2predictor, predictor2name = {}, {}
7584
student, teacher = self.student, self.teacher
7685

77-
assert len(student.predictors()) == len(teacher.predictors()), "Student and teacher must have the same number of predictors."
86+
assert len(student.predictors()) == len(
87+
teacher.predictors(),
88+
), "Student and teacher must have the same number of predictors."
7889

7990
for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()):
8091
assert name1 == name2, "Student and teacher must have the same program structure."
81-
assert predictor1.signature.equals(predictor2.signature), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
92+
assert predictor1.signature.equals(
93+
predictor2.signature,
94+
), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
8295
assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."
8396

84-
name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
97+
name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
8598
predictor2name[id(predictor1)] = name1
8699

87100
# FIXME(shangyint): This is an ugly hack to bind traces of
88101
# retry.module to retry
89102
# if isinstance(predictor1, Retry):
90103
# predictor2name[id(predictor1.module)] = name1
91104

92-
predictor2name[id(predictor2)] = name2
105+
predictor2name[id(predictor2)] = name2
93106

94107
self.name2predictor = name2predictor
95108
self.predictor2name = predictor2name
@@ -111,8 +124,8 @@ def _bootstrap(self, *, max_bootstraps=None):
111124
if success:
112125
bootstrapped[example_idx] = True
113126

114-
print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.')
115-
127+
print(f"Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.")
128+
116129
# Unbootstrapped training examples
117130

118131
self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
@@ -123,10 +136,10 @@ def _bootstrap(self, *, max_bootstraps=None):
123136
# NOTE: Can't yet use evaluate because we need to trace *per example*
124137
# evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
125138
# score = evaluate(self.metric, display_table=False, display_progress=True)
126-
139+
127140
def _bootstrap_one_example(self, example, round_idx=0):
128141
name2traces = self.name2traces
129-
teacher = self.teacher #.deepcopy()
142+
teacher = self.teacher # .deepcopy()
130143
predictor_cache = {}
131144

132145
try:
@@ -145,7 +158,7 @@ def _bootstrap_one_example(self, example, round_idx=0):
145158

146159
for name, predictor in teacher.named_predictors():
147160
predictor.demos = predictor_cache[name]
148-
161+
149162
if self.metric:
150163
metric_val = self.metric(example, prediction, trace)
151164
if self.metric_threshold:
@@ -162,13 +175,13 @@ def _bootstrap_one_example(self, example, round_idx=0):
162175
current_error_count = self.error_count
163176
if current_error_count >= self.max_errors:
164177
raise e
165-
print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')
166-
178+
print(f"Failed to run or to evaluate example {example} with {self.metric} due to {e}.")
179+
167180
if success:
168181
for step in trace:
169182
predictor, inputs, outputs = step
170183

171-
if 'dspy_uuid' in example:
184+
if "dspy_uuid" in example:
172185
demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
173186
else:
174187
# TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
@@ -177,30 +190,34 @@ def _bootstrap_one_example(self, example, round_idx=0):
177190
try:
178191
predictor_name = self.predictor2name[id(predictor)]
179192
except KeyError as e:
180-
continue # FIXME: !
193+
continue # FIXME: !
181194

182195
# TODO: Look closer into this. It's a bit tricky to reproduce.
183-
print(f'Failed to find predictor {predictor} in {self.predictor2name}.')
184-
print('Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.')
185-
print('Try restarting the notebook, or open an issue.')
186-
raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.') from e
196+
print(f"Failed to find predictor {predictor} in {self.predictor2name}.")
197+
print(
198+
"Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.",
199+
)
200+
print("Try restarting the notebook, or open an issue.")
201+
raise KeyError(
202+
f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.",
203+
) from e
187204

188205
name2traces[predictor_name].append(demo)
189-
206+
190207
return success
191208

192209
def _train(self):
193210
rng = random.Random(0)
194211
raw_demos = self.validation
195212

196213
for name, predictor in self.student.named_predictors():
197-
augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]
198-
214+
augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]
215+
199216
sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
200217
sample_size = max(0, sample_size)
201218

202219
raw_demos = rng.sample(raw_demos, sample_size)
203-
220+
204221
if dspy.settings.release >= 20230928:
205222
predictor.demos = raw_demos + augmented_demos
206223
else:

0 commit comments

Comments
 (0)