3131
3232
3333class BootstrapFewShot (Teleprompter ):
def __init__(
    self,
    metric=None,
    metric_threshold=None,
    teacher_settings=None,
    max_bootstrapped_demos=4,
    max_labeled_demos=16,
    max_rounds=1,
    max_errors=5,
):
    """Store the bootstrapping configuration and reset the error bookkeeping.

    Args:
        metric: Optional callable; when set it is invoked as
            ``metric(example, prediction, trace)`` while bootstrapping.
        metric_threshold: Optional threshold consulted together with
            ``metric`` (the exact comparison is outside this block).
        teacher_settings: Optional dict of settings for the teacher run.
            ``None`` (the default) yields a fresh empty dict per instance.
            NOTE(review): how the settings are consumed is not visible
            here -- confirm against ``_bootstrap_one_example``.
        max_bootstrapped_demos: Upper bound on the bootstrapped demos kept
            per predictor when the student is trained.
        max_labeled_demos: Upper bound on demos per predictor; also the
            ``k`` handed to ``LabeledFewShot`` when preparing the teacher.
        max_rounds: Number of bootstrapping rounds (presumably the loop
            bound in ``_bootstrap`` -- TODO confirm).
        max_errors: Once ``error_count`` reaches this value, a failure
            during bootstrapping is re-raised instead of being reported.
    """
    self.metric = metric
    self.metric_threshold = metric_threshold
    # A literal ``{}`` default in the signature is evaluated once and would be
    # shared (and mutated) across every instance created without the argument.
    self.teacher_settings = {} if teacher_settings is None else teacher_settings

    self.max_bootstrapped_demos = max_bootstrapped_demos
    self.max_labeled_demos = max_labeled_demos
    self.max_rounds = max_rounds
    self.max_errors = max_errors

    # Failure counter and the lock that accompanies it.
    self.error_count = 0
    self.error_lock = threading.Lock()
4554
@@ -59,37 +68,41 @@ def compile(self, student, *, teacher=None, trainset, valset=None):
5968 self .student ._suggest_failures = 0
6069
6170 return self .student
62-
71+
def _prepare_student_and_teacher(self, student, teacher):
    """Create the working copies of the student and teacher programs.

    ``self.student`` becomes ``student.reset_copy()``. ``self.teacher``
    becomes ``teacher.deepcopy()`` when a teacher is supplied, otherwise a
    second ``student.reset_copy()``. If ``max_labeled_demos`` is truthy and
    the teacher is not marked compiled, the teacher is additionally compiled
    with ``LabeledFewShot(k=self.max_labeled_demos)`` on ``self.trainset``
    (``self.trainset`` is expected to have been set by the caller before
    this method runs).

    Args:
        student: Program to optimise; must provide ``reset_copy()``.
        teacher: Optional program providing ``deepcopy()`` and
            ``reset_copy()``; ``None`` means the student is its own teacher.

    Raises:
        AssertionError: If the student copy is flagged as compiled.
    """
    self.student = student.reset_copy()
    if teacher is not None:
        self.teacher = teacher.deepcopy()
    else:
        self.teacher = student.reset_copy()

    # The attribute is ``_compiled`` (no leading space): with a stray space in
    # the name, getattr() can only ever return the default and the check is a
    # no-op. Raised explicitly so the guard is not stripped under ``python -O``.
    if getattr(self.student, "_compiled", False) is not False:
        raise AssertionError("Student must be uncompiled.")

    teacher_is_uncompiled = getattr(self.teacher, "_compiled", False) is False
    if self.max_labeled_demos and teacher_is_uncompiled:
        # Seed the uncompiled teacher with labeled demos from the train set.
        teleprompter = LabeledFewShot(k=self.max_labeled_demos)
        self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)
7281
7382 def _prepare_predictor_mappings (self ):
7483 name2predictor , predictor2name = {}, {}
7584 student , teacher = self .student , self .teacher
7685
77- assert len (student .predictors ()) == len (teacher .predictors ()), "Student and teacher must have the same number of predictors."
86+ assert len (student .predictors ()) == len (
87+ teacher .predictors (),
88+ ), "Student and teacher must have the same number of predictors."
7889
7990 for (name1 , predictor1 ), (name2 , predictor2 ) in zip (student .named_predictors (), teacher .named_predictors ()):
8091 assert name1 == name2 , "Student and teacher must have the same program structure."
81- assert predictor1 .signature .equals (predictor2 .signature ), f"Student and teacher must have the same signatures. { type (predictor1 .signature )} != { type (predictor2 .signature )} "
92+ assert predictor1 .signature .equals (
93+ predictor2 .signature ,
94+ ), f"Student and teacher must have the same signatures. { type (predictor1 .signature )} != { type (predictor2 .signature )} "
8295 assert id (predictor1 ) != id (predictor2 ), "Student and teacher must be different objects."
8396
84- name2predictor [name1 ] = None # dict(student=predictor1, teacher=predictor2)
97+ name2predictor [name1 ] = None # dict(student=predictor1, teacher=predictor2)
8598 predictor2name [id (predictor1 )] = name1
8699
87100 # FIXME(shangyint): This is an ugly hack to bind traces of
88101 # retry.module to retry
89102 # if isinstance(predictor1, Retry):
90103 # predictor2name[id(predictor1.module)] = name1
91104
92- predictor2name [id (predictor2 )] = name2
105+ predictor2name [id (predictor2 )] = name2
93106
94107 self .name2predictor = name2predictor
95108 self .predictor2name = predictor2name
@@ -111,8 +124,8 @@ def _bootstrap(self, *, max_bootstraps=None):
111124 if success :
112125 bootstrapped [example_idx ] = True
113126
114- print (f' Bootstrapped { len (bootstrapped )} full traces after { example_idx + 1 } examples in round { round_idx } .' )
115-
127+ print (f" Bootstrapped { len (bootstrapped )} full traces after { example_idx + 1 } examples in round { round_idx } ." )
128+
116129 # Unbootstrapped training examples
117130
118131 self .validation = [x for idx , x in enumerate (self .trainset ) if idx not in bootstrapped ]
@@ -123,10 +136,10 @@ def _bootstrap(self, *, max_bootstraps=None):
123136 # NOTE: Can't yet use evaluate because we need to trace *per example*
124137 # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
125138 # score = evaluate(self.metric, display_table=False, display_progress=True)
126-
139+
127140 def _bootstrap_one_example (self , example , round_idx = 0 ):
128141 name2traces = self .name2traces
129- teacher = self .teacher # .deepcopy()
142+ teacher = self .teacher # .deepcopy()
130143 predictor_cache = {}
131144
132145 try :
@@ -145,7 +158,7 @@ def _bootstrap_one_example(self, example, round_idx=0):
145158
146159 for name , predictor in teacher .named_predictors ():
147160 predictor .demos = predictor_cache [name ]
148-
161+
149162 if self .metric :
150163 metric_val = self .metric (example , prediction , trace )
151164 if self .metric_threshold :
@@ -162,13 +175,13 @@ def _bootstrap_one_example(self, example, round_idx=0):
162175 current_error_count = self .error_count
163176 if current_error_count >= self .max_errors :
164177 raise e
165- print (f' Failed to run or to evaluate example { example } with { self .metric } due to { e } .' )
166-
178+ print (f" Failed to run or to evaluate example { example } with { self .metric } due to { e } ." )
179+
167180 if success :
168181 for step in trace :
169182 predictor , inputs , outputs = step
170183
171- if ' dspy_uuid' in example :
184+ if " dspy_uuid" in example :
172185 demo = Example (augmented = True , dspy_uuid = example .dspy_uuid , ** inputs , ** outputs )
173186 else :
174187 # TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
@@ -177,30 +190,34 @@ def _bootstrap_one_example(self, example, round_idx=0):
177190 try :
178191 predictor_name = self .predictor2name [id (predictor )]
179192 except KeyError as e :
180- continue # FIXME: !
193+ continue # FIXME: !
181194
182195 # TODO: Look closer into this. It's a bit tricky to reproduce.
183- print (f'Failed to find predictor { predictor } in { self .predictor2name } .' )
184- print ('Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.' )
185- print ('Try restarting the notebook, or open an issue.' )
186- raise KeyError (f'Failed to find predictor { id (predictor )} { predictor } in { self .predictor2name } .' ) from e
196+ print (f"Failed to find predictor { predictor } in { self .predictor2name } ." )
197+ print (
198+ "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells." ,
199+ )
200+ print ("Try restarting the notebook, or open an issue." )
201+ raise KeyError (
202+ f"Failed to find predictor { id (predictor )} { predictor } in { self .predictor2name } ." ,
203+ ) from e
187204
188205 name2traces [predictor_name ].append (demo )
189-
206+
190207 return success
191208
192209 def _train (self ):
193210 rng = random .Random (0 )
194211 raw_demos = self .validation
195212
196213 for name , predictor in self .student .named_predictors ():
197- augmented_demos = self .name2traces [name ][:self .max_bootstrapped_demos ]
198-
214+ augmented_demos = self .name2traces [name ][: self .max_bootstrapped_demos ]
215+
199216 sample_size = min (self .max_labeled_demos - len (augmented_demos ), len (raw_demos ))
200217 sample_size = max (0 , sample_size )
201218
202219 raw_demos = rng .sample (raw_demos , sample_size )
203-
220+
204221 if dspy .settings .release >= 20230928 :
205222 predictor .demos = raw_demos + augmented_demos
206223 else :
0 commit comments