diff --git a/dsp/__init__.py b/dsp/__init__.py index 4f98315f13..86b103f10c 100644 --- a/dsp/__init__.py +++ b/dsp/__init__.py @@ -1,7 +1,7 @@ -from .modules import * -from .primitives import * -from .templates import * -from .utils import settings +from .modules import * # noqa +from .primitives import * # noqa +from .adapters import * # noqa +from .utils import settings # noqa """ TODO: diff --git a/dsp/adapters/__init__.py b/dsp/adapters/__init__.py new file mode 100644 index 0000000000..d1c8c759cf --- /dev/null +++ b/dsp/adapters/__init__.py @@ -0,0 +1,4 @@ +from .base_template import * # noqa +from .template import * # noqa +from .experimental_adapter import * # noqa +from .utils import * # noqa \ No newline at end of file diff --git a/dsp/templates/template_v3.py b/dsp/adapters/base_template.py similarity index 91% rename from dsp/templates/template_v3.py rename to dsp/adapters/base_template.py index 605705c153..642f491831 100644 --- a/dsp/templates/template_v3.py +++ b/dsp/adapters/base_template.py @@ -1,6 +1,9 @@ +from collections import namedtuple from typing import Callable -from dsp.templates import Field, TemplateV2, format_answers, passages2text +from .utils import format_answers, passages2text + +Field = namedtuple("Field", "name separator input_variable output_variable description") class Type: @@ -19,7 +22,7 @@ def __eq__(self, __value: object) -> bool: return isinstance(__value, Type) and self.__dict__ == __value.__dict__ -class Template(TemplateV2): +class BaseTemplate: """A template datatype that represents the structure of communicate with the LM.""" def __init__(self, instructions: str, **kwargs): @@ -61,9 +64,7 @@ def __eq__(self, other): v1, v2 = self.kwargs[k], other.kwargs[k] if not v1 == v2: print(k, v1, v2) - - # print("here?", self.instructions == other.instructions, self.kwargs == other.kwargs) return self.instructions == other.instructions and self.kwargs == other.kwargs def __str__(self) -> str: diff --git a/dsp/adapters/experimental_adapter.py b/dsp/adapters/experimental_adapter.py new file mode 100644 index 0000000000..6e2eb3a335 --- /dev/null +++ b/dsp/adapters/experimental_adapter.py @@ -0,0 +1,202 @@ +from typing import Any, Union + +import dsp +from dsp.primitives.demonstrate import Example + +from .base_template import BaseTemplate + + +class ExperimentalAdapter(BaseTemplate): + def query(self, example: Example, is_demo: bool = False) -> str: + """Retrieves the input variables from the example and formats them into a query string.""" + result: list[str] = [] + + # If not a demo, find the last field that doesn't have a value set in `example` and set it to "" + # This creates the "Output:" prefix at the end of the prompt. + if not is_demo: + has_value = [ + field.input_variable in example + and example[field.input_variable] is not None + and example[field.input_variable] != "" + for field in self.fields + ] + + if not any(has_value): + assert False, "No input variables found in the example" + + for i in range(1, len(has_value)): + if has_value[i - 1] and not any(has_value[i:]): + example[self.fields[i].input_variable] = "" + break + + for field in self.fields: + if field.input_variable in example and example[field.input_variable] is not None: + if field.input_variable in self.format_handlers: + format_handler = self.format_handlers[field.input_variable] + else: + def format_handler(x): + return str(x).strip() + + formatted_value = format_handler(example[field.input_variable]) + separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator + + result.append(f"{field.name}{separator}{formatted_value}",) + + return "\n\n".join([r for r in result if r]) + + def guidelines(self, show_guidelines=True) -> str: + """Returns the task guidelines as described in the lm prompt""" + if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines): + return "" + + result = "Follow the following format.\n\n" + + example = dsp.Example() + for field in self.fields: + example[field.input_variable] = field.description + example.augmented = True + + result += self.query(example) + return result + + def extract( + self, + example: Union[Example, dict[str, Any]], + raw_pred: str, + ) -> Example: + """Extracts the answer from the LM raw prediction using the template structure + + Args: + example (Union[Example, dict[str, Any]]): Contains the input variables that raw_pred was completed on. + raw_pred (str): LM generated string + + Returns: + Example: The example with the output variables filled in + """ + example = dsp.Example(example) + + raw_pred = raw_pred.strip() + parts = raw_pred.split('\n') + adjusted_parts = [] + for part in parts: + trimmed_part = part.strip() + if trimmed_part: + if adjusted_parts: + adjusted_parts.append('\n' + trimmed_part) + else: + adjusted_parts.append(trimmed_part) + raw_pred = '\n'.join(adjusted_parts) + + idx = 0 + while idx < len(self.fields): + if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None: + break + idx += 1 + + import dspy + + idx = min(idx, len(self.fields) - 1) + while raw_pred != "" and idx < len(self.fields): + if idx < len(self.fields) - 1: + next_field_name = "\n" + self.fields[idx + 1].name + offset = raw_pred.find(next_field_name) + + if offset >= 0: + if dspy.settings.release >= 20231003: + example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip() + raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip() + else: + field_name_parts = self.fields[idx].name.split() + start_pos = 0 + for part in field_name_parts: + pos = raw_pred.find(part.strip()) + if pos != -1: + start_pos = pos + len(part) + else: + break + + example[self.fields[idx].output_variable] = raw_pred[start_pos:offset].strip().rstrip("---").strip() + raw_pred = raw_pred[offset + len(next_field_name) :].strip() + idx += 1 + else: + example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip() + + raw_pred = "" + idx += 1 + break + + else: + assert idx == len(self.fields) - 1, (idx, len(self.fields)) + + if dspy.settings.release >= 20231003: + example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip() + else: + field_name_parts = self.fields[idx].name.split() + start_pos = 0 + for part in field_name_parts: + pos = raw_pred.find(part.strip()) + if pos != -1: + start_pos = pos + len(part) + else: + break + example[self.fields[idx].output_variable] = raw_pred[start_pos:].strip() + + break + + return example + + def __call__(self, example, show_guidelines=True) -> str: + example = dsp.Example(example) + output_fields = [] + for i in range(len(self.fields)): + if self.fields[i].input_variable not in example: + output_field = self.fields[i].input_variable + if output_field not in output_fields: + output_fields.append(self.fields[i].name.split(':')[0]) + + if hasattr(dsp.settings, "query_only") and dsp.settings.query_only: + return self.query(example) + + # The training data should not contain the output variable + assert self.fields[-1].input_variable not in example, f"Output variable {self.fields[-1].input_variable} should not be supplied for querying the LM." + # del example[self.fields[-1].input_variable] + + rdemos = [ + self.query(demo, is_demo=True) + for demo in example.demos + if ( + (not demo.get("augmented", False)) + and ( # validate that the training example has the same primitive input var as the template + self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None + ) + ) + ] + + ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)] + + # Move the rdemos to ademos if rdemo has all the fields filled in + rdemos_ = [] + new_ademos = [] + for rdemo in rdemos: + if all((field.name in rdemo) for field in self.fields if field.input_variable in example): + new_ademos.append(rdemo) + else: + rdemos_.append(rdemo) + + ademos = new_ademos + ademos + rdemos = rdemos_ + + example["augmented"] = True + + query = self.query(example) + parts = [self.instructions, *rdemos, self.guidelines(show_guidelines), *ademos, query,] + + prompt = "\n\n---\n\n".join([p.strip() for p in parts if p]) + prompt_ = prompt[: prompt.rfind("\n")].strip() + + s_or_not = "s" if len(output_fields) > 1 else "" + only_or_not = "only " if len(output_fields) == 1 else "" + + prompt_ += f"\n\nPlease provide the output field{s_or_not} {', '.join(output_fields[:-1]) + (', then ' if len(output_fields) > 2 else ' then ') + output_fields[-1] if len(output_fields) > 1 else output_fields[0]}. Do so immediately, without additional content before or after, and precisely as the format above shows. Begin with {only_or_not}the field {output_fields[0]}." + return prompt_.strip() + diff --git a/dsp/templates/template_v2.py b/dsp/adapters/template.py similarity index 79% rename from dsp/templates/template_v2.py rename to dsp/adapters/template.py index 0df17185bb..b375dfd229 100644 --- a/dsp/templates/template_v2.py +++ b/dsp/adapters/template.py @@ -1,73 +1,12 @@ -import re -from collections import namedtuple from typing import Any, Union import dsp from dsp.primitives.demonstrate import Example -from .utils import format_answers, passages2text +from .base_template import BaseTemplate -Field = namedtuple("Field", "name separator input_variable output_variable description") - -# TODO: de-duplicate with dsp/templates/template.py - - -class TemplateV2: - def __init__( - self, - template, - format_handlers={ - "passages": passages2text, - "context": passages2text, - "answer": format_answers, - "answers": format_answers, - }, - ): - self.format_handlers = format_handlers - - template = template.strip() - - self.instructions = re.search("(.*)\n", template).group(1) - template = template[len(self.instructions) :].strip() - - self.fields = [] - while len(template) > 0: - match = re.search("(.*)(\s){(.*)}\s(.*\${.*})", template) - if match is not None: - name = match.group(1) - separator = match.group(2) - variable = match.group(3) - description = match.group(4) - else: - match = re.search("(.*)(\s){(.*)}", template) - if match is not None: - name = match.group(1) - separator = match.group(2) - variable = match.group(3) - description = None - else: - raise ValueError("Could not parse template") - - var_match = re.match("(.*) -> (.*)", variable) - if var_match is not None: - input_variable = var_match.group(1) - output_variable = var_match.group(2) - else: - input_variable = variable - output_variable = variable - - self.fields.append( - Field( - name=name, - separator=separator, - input_variable=input_variable, - output_variable=output_variable, - description=description, - ), - ) - - template = template[len(match.group(0)) :].strip() +class Template(BaseTemplate): def query(self, example: Example, is_demo: bool = False) -> str: """Retrieves the input variables from the example and formats them into a query string.""" result: list[str] = [] diff --git a/dsp/templates/utils.py b/dsp/adapters/utils.py similarity index 63% rename from dsp/templates/utils.py rename to dsp/adapters/utils.py index 53624ad2d8..4bc15c0585 100644 --- a/dsp/templates/utils.py +++ b/dsp/adapters/utils.py @@ -17,29 +17,29 @@ def passages2text(passages: Union[str, list, tuple]) -> str: return "\n".join([f"[{idx+1}] «{txt}»" for idx, txt in enumerate(passages)]) -def passages2textV2(passages: Union[str, list, tuple]) -> str: - """Formats the given one or more passages into a single structured string.""" - if isinstance(passages, str): - return passages - - assert type(passages) in [list, tuple] - - def psg2text(psg): - try: - title, snippet = psg.split("|", 1) - return f"Title: {title.strip()} | Snippet: «{snippet.strip()}»" - except Exception: - pass +# def passages2textV2(passages: Union[str, list, tuple]) -> str: +# """Formats the given one or more passages into a single structured string.""" +# if isinstance(passages, str): +# return passages + +# assert type(passages) in [list, tuple] + +# def psg2text(psg): +# try: +# title, snippet = psg.split("|", 1) +# return f"Title: {title.strip()} | Snippet: «{snippet.strip()}»" +# except Exception: +# pass - return f"«{psg}»" +# return f"«{psg}»" - if len(passages) == 0: - return "N/A" +# if len(passages) == 0: +# return "N/A" - if len(passages) == 1: - return psg2text(passages[0]) +# if len(passages) == 1: +# return psg2text(passages[0]) - return "\n".join([f"[{idx+1}] {psg2text(txt)}" for idx, txt in enumerate(passages)]) +# return "\n".join([f"[{idx+1}] {psg2text(txt)}" for idx, txt in enumerate(passages)]) def format_answers(answers: Union[str, list]) -> Optional[str]: diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index 7f7966d4bc..3925e83b34 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -26,7 +26,12 @@ def request(self, prompt, **kwargs): return self.basic_request(prompt, **kwargs) def print_green(self, text: str, end: str = "\n"): - return "\x1b[32m" + str(text) + "\x1b[0m" + end + import dspy + + if dspy.settings.experimental: + return "\n\n" + "\x1b[32m" + str(text) + "\x1b[0m" + end + else: + return "\x1b[32m" + str(text) + "\x1b[0m" + end def print_red(self, text: str, end: str = "\n"): return "\x1b[31m" + str(text) + "\x1b[0m" + end @@ -113,7 +118,7 @@ def inspect_history(self, n: int = 1, skip: int = 0): text = choices[0] else: text = choices[0]["text"] - printing_value += self.print_green(text, end="") + printing_value += self.print_green(text.lstrip(), end="") if len(choices) > 1 and isinstance(choices, list): printing_value += self.print_red( @@ -134,4 +139,4 @@ def copy(self, **kwargs): kwargs = {**self.kwargs, **kwargs} model = kwargs.pop("model") - return self.__class__(model=model, **kwargs) + return self.__class__(model=model, **kwargs) \ No newline at end of file diff --git a/dsp/primitives/compiler.py b/dsp/primitives/compiler.py index 625ae23ac7..a004a9834e 100644 --- a/dsp/primitives/compiler.py +++ b/dsp/primitives/compiler.py @@ -1,173 +1,173 @@ -import os -import random -import subprocess -import time +# import os +# import random +# import subprocess +# import time -import tqdm -import ujson -from datasets.fingerprint import Hasher +# import tqdm +# import ujson +# from datasets.fingerprint import Hasher -import dsp +# import dsp -if os.environ.get('DSP_NOTEBOOK_CACHEDIR'): - training_data_directory = os.path.join(os.environ.get('DSP_NOTEBOOK_CACHEDIR'), 'compiler') -else: - training_data_directory = 'cache/compiler' +# if os.environ.get('DSP_NOTEBOOK_CACHEDIR'): +# training_data_directory = os.path.join(os.environ.get('DSP_NOTEBOOK_CACHEDIR'), 'compiler') +# else: +# training_data_directory = 'cache/compiler' -compilations_assumed_to_exist={'ft-zvEdzQVQ5xwlxvNPrxl6kpnw': 'ada:ft-stanfordpraglab-2023-02-09-19-50-49'} +# compilations_assumed_to_exist={'ft-zvEdzQVQ5xwlxvNPrxl6kpnw': 'ada:ft-stanfordpraglab-2023-02-09-19-50-49'} -def openai_check_finetune(jobname): - if dsp.settings.force_reuse_cached_compilation and jobname in compilations_assumed_to_exist: - return compilations_assumed_to_exist[jobname] +# def openai_check_finetune(jobname): +# if dsp.settings.force_reuse_cached_compilation and jobname in compilations_assumed_to_exist: +# return compilations_assumed_to_exist[jobname] - command = f"""openai api fine_tunes.get -i {jobname}""" - print(command) +# command = f"""openai api fine_tunes.get -i {jobname}""" +# print(command) - result = subprocess.run(command.split(), stdout=subprocess.PIPE, check=False) - output = result.stdout.decode("utf-8").strip() +# result = subprocess.run(command.split(), stdout=subprocess.PIPE, check=False) +# output = result.stdout.decode("utf-8").strip() - try: - output = ujson.loads(output) - if output['status'] == 'succeeded': - return output['fine_tuned_model'] +# try: +# output = ujson.loads(output) +# if output['status'] == 'succeeded': +# return output['fine_tuned_model'] - if output['status'] in ['pending', 'running']: - print(f'Compiling, run ```openai api fine_tunes.follow -i {jobname}``` for details...') - time.sleep(60) - return openai_check_finetune(jobname) - except: - pass +# if output['status'] in ['pending', 'running']: +# print(f'Compiling, run ```openai api fine_tunes.follow -i {jobname}``` for details...') +# time.sleep(60) +# return openai_check_finetune(jobname) +# except: +# pass - return False +# return False -def convert_to_training_point2(y, inputs, outputs, template): - assert len(inputs) + len(outputs) == len(template.fields) +# def convert_to_training_point2(y, inputs, outputs, template): +# assert len(inputs) + len(outputs) == len(template.fields) - y_ = dsp.Example(**{f: y[f] for f in inputs}, demos=[]) - prompt = template(y_, show_guidelines=False) +# y_ = dsp.Example(**{f: y[f] for f in inputs}, demos=[]) +# prompt = template(y_, show_guidelines=False) - completion = y[outputs[0]] - output_fields = template.fields[len(inputs):] +# completion = y[outputs[0]] +# output_fields = template.fields[len(inputs):] - for field in output_fields[1:]: - completion += f"\n\n{field.name} " + y[field.output_variable] +# for field in output_fields[1:]: +# completion += f"\n\n{field.name} " + y[field.output_variable] - completion = " " + completion + " " - return {'prompt': prompt, 'completion': completion} +# completion = " " + completion + " " +# return {'prompt': prompt, 'completion': completion} -def simulate(program, input_examples): - training_data = [] +# def simulate(program, input_examples): +# training_data = [] - for input_example in tqdm.tqdm(input_examples): - prediction = program(input_example) +# for input_example in tqdm.tqdm(input_examples): +# prediction = program(input_example) - if prediction is not None: - # assert len(prediction.compiling_stages) == 2, "TMP" - for stage in prediction.compiling_stages: - name, template, inputs, outputs = stage['name'], stage['template'], stage['inputs'], stage['outputs'] - training_data.append(convert_to_training_point2(prediction.get(name), inputs, outputs, template)) +# if prediction is not None: +# # assert len(prediction.compiling_stages) == 2, "TMP" +# for stage in prediction.compiling_stages: +# name, template, inputs, outputs = stage['name'], stage['template'], stage['inputs'], stage['outputs'] +# training_data.append(convert_to_training_point2(prediction.get(name), inputs, outputs, template)) - r = random.Random(0) - r.shuffle(training_data) +# r = random.Random(0) +# r.shuffle(training_data) - return training_data +# return training_data -def openai_finetune_(name, target): - training_data_path = name_to_path(name) +# def openai_finetune_(name, target): +# training_data_path = name_to_path(name) - # Launch the fine-tune on the path - command = f"""openai api fine_tunes.create -t {training_data_path} -m {target} --n_epochs 4 --learning_rate_multiplier 0.05 --no_check_if_files_exist""" - print(command) +# # Launch the fine-tune on the path +# command = f"""openai api fine_tunes.create -t {training_data_path} -m {target} --n_epochs 4 --learning_rate_multiplier 0.05 --no_check_if_files_exist""" +# print(command) - # command = """python script.py""" - process = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) +# # command = """python script.py""" +# process = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - while line := process.stdout.readline().decode().strip(): - if 'created fine-tune:' in line.lower(): - jobname = line.split()[-1] - break +# while line := process.stdout.readline().decode().strip(): +# if 'created fine-tune:' in line.lower(): +# jobname = line.split()[-1] +# break - # if 'costs $' in line.lower(): - # cost = line.split()[-1] - # break +# # if 'costs $' in line.lower(): +# # cost = line.split()[-1] +# # break - # assert cost[0] == '$' +# # assert cost[0] == '$' - # if float(cost[1:]) > 300: - # print(f'Got cost {cost} -- you may wanna cancel the job: openai api fine_tunes.cancel -i {jobname}') +# # if float(cost[1:]) > 300: +# # print(f'Got cost {cost} -- you may wanna cancel the job: openai api fine_tunes.cancel -i {jobname}') - # print(cost) +# # print(cost) - print(jobname) +# print(jobname) - # Block until it's done - ft = openai_check_finetune(jobname) - assert ft, ft +# # Block until it's done +# ft = openai_check_finetune(jobname) +# assert ft, ft - # Return its name - return (jobname, ft) +# # Return its name +# return (jobname, ft) -def openai_finetune(name, target): - print(name) - training_data_path = name_to_path(name) - training_data_path += '.model' +# def openai_finetune(name, target): +# print(name) +# training_data_path = name_to_path(name) +# training_data_path += '.model' - # if path + stuff exists, load the tuple from it - try: - with open(training_data_path) as f: - jobname, ft = ujson.loads(f.readline()) +# # if path + stuff exists, load the tuple from it +# try: +# with open(training_data_path) as f: +# jobname, ft = ujson.loads(f.readline()) - if openai_check_finetune(jobname): - return jobname, ft - except: - pass +# if openai_check_finetune(jobname): +# return jobname, ft +# except: +# pass - jobname, ft = openai_finetune_(name, target) +# jobname, ft = openai_finetune_(name, target) - with open(training_data_path, 'w') as f: - f.write(ujson.dumps((jobname, ft)) + '\n') +# with open(training_data_path, 'w') as f: +# f.write(ujson.dumps((jobname, ft)) + '\n') - return jobname, ft +# return jobname, ft -def name_to_path(name): - if not os.path.exists(training_data_directory): - os.makedirs(training_data_directory) +# def name_to_path(name): +# if not os.path.exists(training_data_directory): +# os.makedirs(training_data_directory) - training_data_path = os.path.join(training_data_directory, f'{name}.jsonl') - return training_data_path +# training_data_path = os.path.join(training_data_directory, f'{name}.jsonl') +# return training_data_path -# 3. Check that the output file name has status "success" (not deleted or non-existent). Otherwise, re-call with n = n+1. -def finetune(training_data, target): - name = Hasher.hash(training_data) - training_data_path = name_to_path(name) +# # 3. Check that the output file name has status "success" (not deleted or non-existent). Otherwise, re-call with n = n+1. +# def finetune(training_data, target): +# name = Hasher.hash(training_data) +# training_data_path = name_to_path(name) - with open(training_data_path, 'w') as f: - for line in training_data: - f.write(ujson.dumps(line) + '\n') +# with open(training_data_path, 'w') as f: +# for line in training_data: +# f.write(ujson.dumps(line) + '\n') - jobname, ft = openai_finetune(name, target) - print(ft) +# jobname, ft = openai_finetune(name, target) +# print(ft) - ft = dsp.GPT3(model=ft, stop=" ") - return ft +# ft = dsp.GPT3(model=ft, stop=" ") +# return ft -# 4. Return updated program. -def compile(program, examples, target='ada'): - training_data = simulate(program, examples) - compiled_lm = finetune(training_data, target=target) +# # 4. Return updated program. +# def compile(program, examples, target='ada'): +# training_data = simulate(program, examples) +# compiled_lm = finetune(training_data, target=target) - def compiled_program(*args, **kwargs): - with dsp.settings.context(compiled_lm=compiled_lm, compiling=False): - return program(*args, **kwargs) +# def compiled_program(*args, **kwargs): +# with dsp.settings.context(compiled_lm=compiled_lm, compiling=False): +# return program(*args, **kwargs) - compiled_program.lm = compiled_lm - return compiled_program +# compiled_program.lm = compiled_lm +# return compiled_program diff --git a/dsp/primitives/demonstrate.py b/dsp/primitives/demonstrate.py index 215afc23d2..331c31517e 100644 --- a/dsp/primitives/demonstrate.py +++ b/dsp/primitives/demonstrate.py @@ -1,4 +1,3 @@ -import random from typing import Any, Callable import numpy as np @@ -45,62 +44,62 @@ def at(example): return self.copy(demos=demos) -def annotate(*transformations): - """Returns an Augment function that applies the provided transformations to the Examples""" +# def annotate(*transformations): +# """Returns an Augment function that applies the provided transformations to the Examples""" - def do_augment(train, k=None, return_all=False): - rdemos = [] - ademos = [] +# def do_augment(train, k=None, return_all=False): +# rdemos = [] +# ademos = [] - for example in train: # tqdm.tqdm - raw_example = dsp.Example(example) +# for example in train: # tqdm.tqdm +# raw_example = dsp.Example(example) - if (k is not None) and len(ademos) >= k: - example = None +# if (k is not None) and len(ademos) >= k: +# example = None - for f in transformations: - if example is None: - break +# for f in transformations: +# if example is None: +# break - example = f(example) +# example = f(example) - if example is not None: - example.augmented = True - ademos.append(example) - else: - raw_example.augmented = False - rdemos.append(raw_example) +# if example is not None: +# example.augmented = True +# ademos.append(example) +# else: +# raw_example.augmented = False +# rdemos.append(raw_example) - if return_all: - return ademos + rdemos +# if return_all: +# return ademos + rdemos - return ademos +# return ademos - return do_augment +# return do_augment -def sample(train: list[Example], k: int): - """Sample k examples from train.""" - rng = random.Random(dsp.settings.branch_idx) - shuffled_train = [dsp.Example(example) for example in train] - rng.shuffle(shuffled_train) +# def sample(train: list[Example], k: int): +# """Sample k examples from train.""" +# rng = random.Random(dsp.settings.branch_idx) +# shuffled_train = [dsp.Example(example) for example in train] +# rng.shuffle(shuffled_train) - return shuffled_train[:k] +# return shuffled_train[:k] -def all_but(train: list[Example], x: Example) -> list[Example]: - """Removes the example x from the train set by comparing the question and history.""" +# def all_but(train: list[Example], x: Example) -> list[Example]: +# """Removes the example x from the train set by comparing the question and history.""" - output = [ - y - for y in train - if not set.intersection( - set(x.get("history", []) + [x.question]), - set(y.get("history", []) + [y.question]), - ) - ] +# output = [ +# y +# for y in train +# if not set.intersection( +# set(x.get("history", []) + [x.question]), +# set(y.get("history", []) + [y.question]), +# ) +# ] - return output +# return output def passage_match(passages: list[str], answers: list[str]) -> bool: diff --git a/dsp/primitives/inspect.py b/dsp/primitives/inspect.py index ebcd3b7872..edc8345d5c 100644 --- a/dsp/primitives/inspect.py +++ b/dsp/primitives/inspect.py @@ -1,91 +1,91 @@ -import inspect -import json -import random -import string - -import requests - - -class FuncInspector: - def __init__(self): - self.calls = [] - - - def inspect_inner(self, func, function_calls): - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - self.merge_result(result, function_calls) - return result - return wrapper - - - def inspect_func(self, func): - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - stack = inspect.stack() - function_calls = [] - for i in range(len(stack)): - if stack[i][3] == "": - break - if stack[i][3] != "wrapper": - function_calls.append(stack[i][3]) - function_calls.reverse() - result = self.inspect_inner(result, function_calls) - return result - return wrapper +# import inspect +# import json +# import random +# import string + +# import requests + + +# class FuncInspector: +# def __init__(self): +# self.calls = [] + + +# def inspect_inner(self, func, function_calls): +# def wrapper(*args, **kwargs): +# result = func(*args, **kwargs) +# self.merge_result(result, function_calls) +# return result +# return wrapper + + +# def inspect_func(self, func): +# def wrapper(*args, **kwargs): +# result = func(*args, **kwargs) +# stack = inspect.stack() +# function_calls = [] +# for i in range(len(stack)): +# if stack[i][3] == "": +# break +# if stack[i][3] != "wrapper": +# function_calls.append(stack[i][3]) +# function_calls.reverse() +# result = self.inspect_inner(result, function_calls) +# return result +# return wrapper - def parse(self, obj, delete_empty=False): - if isinstance(obj, list): - for elem in obj: - self.parse(elem, delete_empty) - if isinstance(obj, dict): - to_delete = [] - for key in obj: - if delete_empty and not obj[key] or key == "completions": - to_delete.append(key) - else: - self.parse(obj[key], delete_empty) - for key in to_delete: - obj.pop(key) - - - def merge_result(self, result, function_calls): - prev_list = self.calls - prev_call = {} if not prev_list else prev_list[-1] - for call in function_calls[:-1]: - if call not in prev_call: - prev_call = {call: []} - prev_list.append(prev_call) - prev_list = prev_call[call] - prev_call = {} if not prev_list else prev_list[-1] - - example_obj = result[0] - self.parse(example_obj) - prev_list.append({ function_calls[-1]: example_obj }) - - - def view_data(self): - chars = string.digits + string.ascii_lowercase - id = ''.join(random.choices(chars, k=8)) - - post_url = 'http://127.0.0.1:5000/log-item' - parsed_calls = self.calls.copy() - self.parse(parsed_calls, delete_empty=True) - data = {'id': id, 'content': parsed_calls} - response = requests.post(post_url, json=data) +# def parse(self, obj, delete_empty=False): +# if isinstance(obj, list): +# for elem in obj: +# self.parse(elem, delete_empty) +# if isinstance(obj, dict): +# to_delete = [] +# for key in obj: +# if delete_empty and not obj[key] or key == "completions": +# to_delete.append(key) +# else: +# self.parse(obj[key], delete_empty) +# for key in to_delete: +# obj.pop(key) + + +# def merge_result(self, result, function_calls): +# prev_list = self.calls +# prev_call = {} if not prev_list else prev_list[-1] +# for call in function_calls[:-1]: +# if call not in prev_call: +# prev_call = {call: []} +# prev_list.append(prev_call) +# prev_list = prev_call[call] +# prev_call = {} if not prev_list else prev_list[-1] + +# example_obj = result[0] +# self.parse(example_obj) +# prev_list.append({ function_calls[-1]: example_obj }) + + +# def view_data(self): +# chars = string.digits + string.ascii_lowercase +# id = ''.join(random.choices(chars, k=8)) + +# post_url = 'http://127.0.0.1:5000/log-item' +# parsed_calls = self.calls.copy() +# self.parse(parsed_calls, delete_empty=True) +# data = {'id': id, 'content': parsed_calls} +# response = requests.post(post_url, json=data) - if response.status_code == 201: - print('Data created successfully') - else: - print(f'Error sending data to server: {response.status_code}') - return +# if response.status_code == 201: +# print('Data created successfully') +# else: +# print(f'Error sending data to server: {response.status_code}') +# return - frontend_url = f"http://localhost:3000?id={id}" - print(f"View the data here, {frontend_url}") +# frontend_url = f"http://localhost:3000?id={id}" +# print(f"View the data here, {frontend_url}") - def output_json(self, out_path): - f = open(out_path, "w") - json_object = json.dumps(self.calls, indent=2) - f.write(json_object) +# def output_json(self, out_path): +# f = open(out_path, "w") +# json_object = json.dumps(self.calls, indent=2) +# f.write(json_object) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index 41edf6c51c..f524a7aaac 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -1,11 +1,9 @@ -from collections import Counter -from typing import Any, Callable, Optional +from typing import Any, Callable import dsp +from dsp.adapters import Template from dsp.primitives.demonstrate import Example -from dsp.templates.template_v3 import Template -from dsp.utils import normalize_text, zipstar -from dsp.utils.utils import dotdict +from dsp.utils import zipstar class Completions: @@ -61,9 +59,7 @@ def _generate(template: Template, **kwargs) -> Callable: generator = dsp.settings.lm - def do_generate( - example: Example, stage: str, max_depth: int = 2, original_example=None, - ): + def do_generate(example: Example, stage: str, max_depth: int = 2, original_example=None): if not dsp.settings.lm: raise AssertionError("No LM is loaded.") original_example = original_example or example @@ -82,9 +78,7 @@ def do_generate( last_field_idx = 0 for field_idx, key in enumerate(field_names): - completions_ = [ - c for c in completions if key in c.keys() and c[key] is not None - ] + completions_ = [c for c in completions if key in c.keys() and c[key] is not None] # Filter out completions that are missing fields that are present in at least one completion. if len(completions_): @@ -100,137 +94,129 @@ def do_generate( # Recurse with greedy decoding and a shorter length. max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"]) max_tokens = min(max(75, max_tokens // 2), max_tokens) - new_kwargs = { - **kwargs, - "max_tokens": max_tokens, - "n": 1, - "temperature": 0.0, - } + new_kwargs = {**kwargs, "max_tokens": max_tokens, "n": 1, "temperature": 0.0,} assert max_depth > 0 - return generate(template, **new_kwargs)( - completion, - stage=stage, - max_depth=max_depth - 1, - original_example=original_example, - ) + return generate(template, **new_kwargs)(completion, stage=stage, + max_depth=max_depth - 1, + original_example=original_example,) completions = Completions(completions, template=template) example = example.copy(completions=completions) - if len(completions) == 1: - completion = completions[0] - example[stage] = example.copy(**completion) - - if dsp.settings.compiling: - inputs_ = set(original_example.keys()) - inputs = [ - f.input_variable - for f in template.fields - if f.input_variable in inputs_ - ] - outputs = [ - f.output_variable - for f in template.fields - if f.input_variable not in inputs_ - ] - - example.compiling_stages = example.get("compiling_stages", []) - example.compiling_stages.append( - { - "name": stage, - "template": template, - "inputs": inputs, - "outputs": outputs, - }, - ) - else: - # assert not dsp.settings.compiling, "TODO: At this point, cannot compile n>1 generations" - example[stage] = dotdict(completions=completions) + # if len(completions) == 1: + # completion = completions[0] + # example[stage] = example.copy(**completion) + + # if dsp.settings.compiling: + # inputs_ = set(original_example.keys()) + # inputs = [ + # f.input_variable + # for f in template.fields + # if f.input_variable in inputs_ + # ] + # outputs = [ + # f.output_variable + # for f in template.fields + # if f.input_variable not in inputs_ + # ] + + # example.compiling_stages = example.get("compiling_stages", []) + # example.compiling_stages.append( + # { + # "name": stage, + # "template": template, + # "inputs": inputs, + # "outputs": outputs, + # }, + # ) + # else: + # # assert not dsp.settings.compiling, "TODO: At this point, cannot compile n>1 generations" + # example[stage] = dotdict(completions=completions) return example, completions return do_generate -def generate_sc( - example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs, -): - if not dsp.settings.lm: - raise AssertionError("No LM is loaded.") - kwargs = {"temperature": 0.7, "n": 20, "max_tokens": 150, **kwargs} +# def generate_sc( +# example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs, +# ): +# if not dsp.settings.lm: +# raise AssertionError("No LM is loaded.") +# kwargs = {"temperature": 0.7, "n": 20, "max_tokens": 150, **kwargs} - completions = dsp.settings.lm(prompt, **kwargs) - completions = extract_final_answer(example, completions, extract=extract) - return majority_vote_( - completions, normalize=normalize, prediction_field=prediction_field, - ) +# completions = dsp.settings.lm(prompt, **kwargs) +# completions = extract_final_answer(example, completions, extract=extract) +# return majority_vote_( +# completions, normalize=normalize, prediction_field=prediction_field, +# ) -def extract_final_answer(example, completions, extract=None): - if not dsp.settings.lm: - raise AssertionError("No LM is loaded.") - if extract: - completions = [extract(example, p) for p in completions] - else: - completions = [ - p.strip().split("\n")[-1].split(":", 1)[-1].strip() for p in completions - ] +# def extract_final_answer(example, completions, extract=None): +# if not dsp.settings.lm: +# raise AssertionError("No LM is loaded.") +# if extract: +# completions = [extract(example, p) for p in completions] +# else: +# completions = [ +# p.strip().split("\n")[-1].split(":", 1)[-1].strip() for p in completions +# ] - # TODO: make thread-safe? - dsp.settings.lm.history.append( - {**dsp.settings.lm.history[-1], "completions": completions}, - ) +# # TODO: make thread-safe? +# dsp.settings.lm.history.append( +# {**dsp.settings.lm.history[-1], "completions": completions}, +# ) - return completions +# return completions -def majority( - completions: Completions, normalize: bool = True, field: Optional[str] = None, -): - """Returns the most common completion for the target field or the last field in the template.""" - field = completions.template.fields[-1].output_variable if field is None else field +# def majority( +# completions: Completions, normalize: bool = True, field: Optional[str] = None, +# ): +# """Returns the most common completion for the target field or the last field in the template.""" +# field = completions.template.fields[-1].output_variable if field is None else field - return Completions( - majority_vote_(completions, normalize=normalize, prediction_field=field), - template=completions.template, - ) +# return Completions( +# majority_vote_(completions, normalize=normalize, prediction_field=field), +# template=completions.template, +# ) -def majority_vote_(completions: Completions, normalize: bool, prediction_field: str): - """Core logic for majority vote.""" +# def majority_vote_(completions: Completions, normalize: bool, prediction_field: str): +# """Core logic for majority vote.""" - if not dsp.settings.lm: - raise AssertionError("No LM is loaded.") +# if not dsp.settings.lm: +# raise AssertionError("No LM is loaded.") - normalized_to_original = {} - if normalize: - original_completions = completions - completions_ = [] - for pred in completions: - if prediction_field in pred: - completions_.append(normalize_text(pred[prediction_field])) - else: - completions_.append("") - completions = completions_ +# normalized_to_original = {} +# if normalize: +# original_completions = completions +# completions_ = [] +# for pred in completions: +# if prediction_field in pred: +# completions_.append(normalize_text(pred[prediction_field])) +# else: +# completions_.append("") +# completions = completions_ - for completion, normalized_completion in zip(original_completions, completions): - if normalized_completion not in normalized_to_original: - normalized_to_original[normalized_completion] = completion +# for completion, normalized_completion in zip(original_completions, completions): +# if normalized_completion not in normalized_to_original: +# normalized_to_original[normalized_completion] = completion - completions_ = [x for x in completions if x] +# completions_ = [x for x in completions if x] - if completions_: - completions = completions_ +# if completions_: +# completions = completions_ - topk = Counter(completions).most_common() - pred, _ = topk[0] +# topk = Counter(completions).most_common() +# pred, _ = topk[0] - if normalize: - pred = normalized_to_original[pred] +# if normalize: +# pred = normalized_to_original[pred] - dsp.settings.lm.history.append( - {**dsp.settings.lm.history[-1], "topk": topk, "completions": [pred]}, - ) +# dsp.settings.lm.history.append( +# {**dsp.settings.lm.history[-1], "topk": topk, "completions": [pred]}, +# ) - return [pred] +# return [pred] diff --git a/dsp/primitives/primitives.py b/dsp/primitives/primitives.py index 2e4c4bfc07..50aea522b4 100644 --- a/dsp/primitives/primitives.py +++ b/dsp/primitives/primitives.py @@ -1,46 +1,46 @@ -from functools import wraps +# from functools import wraps -import dsp +# import dsp -# applied right to left (innermost first, like function calls) -def compose_decorators(*decorators): - def decorator(func): - for decorator in decorators[::-1]: - func = decorator(func) - return func - return decorator +# # applied right to left (innermost first, like function calls) +# def compose_decorators(*decorators): +# def decorator(func): +# for decorator in decorators[::-1]: +# func = decorator(func) +# return func +# return decorator -def shallow_copy_example_args(func): - @wraps(func) - def wrapper(*args, **kwargs): - args = [dsp.Example(arg) if isinstance(arg, dsp.Example) else arg for arg in args] - kwargs = {key: dsp.Example(value) if isinstance(value, dsp.Example) else value for key, value in kwargs.items()} - return func(*args, **kwargs) - return wrapper +# def shallow_copy_example_args(func): +# @wraps(func) +# def wrapper(*args, **kwargs): +# args = [dsp.Example(arg) if isinstance(arg, dsp.Example) else arg for arg in args] +# kwargs = {key: dsp.Example(value) if isinstance(value, dsp.Example) else value for key, value in kwargs.items()} +# return func(*args, **kwargs) +# return wrapper -transformation = shallow_copy_example_args -# transformation = compose_decorators(handle_compilation, shallow_copy_example_args) +# transformation = shallow_copy_example_args +# # transformation = compose_decorators(handle_compilation, shallow_copy_example_args) -def compiled(func): - def wrapper(*args, **kwargs): - is_to_be_compiled = True #decorator_kwargs.get('compile', False) - compiled_lm = dsp.settings.compiled_lm +# def compiled(func): +# def wrapper(*args, **kwargs): +# is_to_be_compiled = True #decorator_kwargs.get('compile', False) +# compiled_lm = dsp.settings.compiled_lm - if is_to_be_compiled and compiled_lm: - assert len(args) == 1, len(args) - example = args[0] +# if is_to_be_compiled and compiled_lm: +# assert len(args) == 1, len(args) +# example = args[0] - with dsp.settings.context(lm=compiled_lm, show_guidelines=False): - old_demos = list(example.demos) - example = func(example.copy(demos=[]), **kwargs) - return example.copy(demos=old_demos) +# with dsp.settings.context(lm=compiled_lm, show_guidelines=False): +# old_demos = list(example.demos) +# example = func(example.copy(demos=[]), **kwargs) +# return example.copy(demos=old_demos) - with dsp.settings.context(compiling=True): - return func(*args, **kwargs) +# with dsp.settings.context(compiling=True): +# return func(*args, **kwargs) - return wrapper +# return wrapper diff --git a/dsp/templates/__init__.py b/dsp/templates/__init__.py deleted file mode 100644 index 5eaae95f75..0000000000 --- a/dsp/templates/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .template_v2 import * -from .template_v3 import * -from .utils import * - diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 716375b70d..37cab1264b 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -32,7 +32,7 @@ def __new__(cls): reranker=None, compiled_lm=None, force_reuse_cached_compilation=False, - compiling=False, + compiling=False, # TODO: can probably be removed skip_logprobs=False, trace=[], release=0, @@ -41,6 +41,7 @@ def __new__(cls): assert_failures=0, suggest_failures=0, langchain_history=[], + experimental=False, ) cls._instance.__append(config) diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 8b0025b7ad..2003547698 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -8,7 +8,7 @@ from pydantic.fields import FieldInfo import dspy -from dsp.templates import passages2text +from dsp.adapters import passages2text from dspy.primitives.prediction import Prediction from dspy.signatures.signature import ensure_signature, make_signature diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 3325600947..086009e5d2 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -61,6 +61,8 @@ def __call__(self, **kwargs): return self.forward(**kwargs) def forward(self, **kwargs): + assert not dsp.settings.compiling, "It's no longer ever the case that .compiling is True" + # Extract the three privileged keyword arguments. new_signature = ensure_signature(kwargs.pop("new_signature", None)) signature = ensure_signature(kwargs.pop("signature", self.signature)) @@ -83,11 +85,6 @@ def forward(self, **kwargs): config["temperature"] = 0.7 # print(f"#> Setting temperature to 0.7 since n={num_generations} and prior temperature={temperature}.") - # All of the other kwargs are presumed to fit a prefix of the signature. - # That is, they are input variables for the bottom most generation, so - # we place them inside the input - x - together with the demos. - x = dsp.Example(demos=demos, **kwargs) - if new_signature is not None: signature = new_signature @@ -96,29 +93,10 @@ def forward(self, **kwargs): missing = [k for k in signature.input_fields if k not in kwargs] print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.") - # Switch to legacy format for dsp.generate - template = signature_to_template(signature) - - if self.lm is None: - x, C = dsp.generate(template, **config)(x, stage=self.stage) + if dsp.settings.experimental: + completions = new_generate(lm, signature, dsp.Example(demos=demos, **kwargs), **config) else: - # Note: query_only=True means the instructions and examples are not included. - # I'm not really sure why we'd want to do that, but it's there. - with dsp.settings.context(lm=self.lm, query_only=True): - x, C = dsp.generate(template, **config)(x, stage=self.stage) - - assert self.stage in x, "The generated (input, output) example was not stored" - - completions = [] - - for c in C: - completions.append({}) - for field in template.fields: - if field.output_variable not in kwargs.keys(): - completions[-1][field.output_variable] = getattr( - c, - field.output_variable, - ) + completions = old_generate(demos, signature, kwargs, config, self.lm, self.stage) pred = Prediction.from_completions(completions, signature=signature) @@ -138,6 +116,74 @@ def __repr__(self): return f"{self.__class__.__name__}({self.signature})" + +def old_generate(demos, signature, kwargs, config, lm, stage): + # Switch to legacy format for dsp.generate + x = dsp.Example(demos=demos, **kwargs) + template = signature_to_template(signature) + + if lm is None: + x, C = dsp.generate(template, **config)(x, stage=stage) + else: + # Note: query_only=True means the instructions and examples are not included. + with dsp.settings.context(lm=lm, query_only=True): + x, C = dsp.generate(template, **config)(x, stage=stage) + + # assert stage in x, "The generated (input, output) example was not stored" + + completions = [] + + for c in C: + completions.append({}) + for field in template.fields: + if field.output_variable not in kwargs.keys(): + completions[-1][field.output_variable] = getattr( + c, + field.output_variable, + ) + + return completions + + +def new_generate(lm, signature, example, max_depth=6, **kwargs): + kwargs['stop'] = tuple(kwargs.get('stop', [])) or ('\n---', ) + + # Generate and extract the fields. + template = signature_to_template(signature, adapter=dsp.ExperimentalAdapter) + prompt = template(example) + completions = lm(prompt, **kwargs) + completions = [template.extract(example, p) for p in completions] + + assert all(set(signature.input_fields).issubset(set(c.keys())) for c in completions), "Missing input keys." + + # Find the completions that are most complete. + field_names = [field.input_variable for field in template.fields] + for field_idx, key in enumerate(field_names): + completions_ = [c for c in completions if key in c.keys() and c[key] is not None] + completions = completions_ or completions + if len(completions_) == 0: break + + # If none of the completions is completed (i.e., none has the final field set). + if len(completions_) == 0: + # Pick the first completion that has gone farthest. + completion = completions[0] + + for field_idx_ in range(field_idx+1, len(field_names)): + if field_names[field_idx_] in completion: del completion[field_names[field_idx_]] + + # Recurse with greedy decoding. + new_kwargs = {**kwargs, "n": 1, "temperature": 0.0,} + + assert max_depth > 0 + return new_generate(lm, signature, completion, max_depth=max_depth-1, **new_kwargs) + + # Keep only output fields. + completions = [{k: v for k, v in c.items() if k in signature.output_fields} for c in completions] + + return completions + + + # TODO: get some defaults during init from the context window? # # TODO: FIXME: Hmm, I guess expected behavior is that contexts can # affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates. diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 77b51e622a..228a0a5522 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -14,9 +14,12 @@ from dspy.signatures.field import InputField, OutputField, new_to_old_field -def signature_to_template(signature) -> dsp.Template: +def signature_to_template(signature, adapter=None) -> dsp.Template: """Convert from new to legacy format.""" - return dsp.Template( + + adapter = adapter or dsp.Template + + return adapter( signature.instructions, **{name: new_to_old_field(field) for name, field in signature.fields.items()}, ) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 91f0de1863..69e13d4b35 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -156,7 +156,7 @@ def _bootstrap(self, *, max_bootstraps=None): if success: bootstrapped[example_idx] = True - dspy.logger.info( + print( f"Bootstrapped {len(bootstrapped)} full traces after {example_idx + 1} examples in round {round_idx}.", ) diff --git a/dspy/teleprompt/random_search.py b/dspy/teleprompt/random_search.py index 2b9c4e5925..23e3e03dd6 100644 --- a/dspy/teleprompt/random_search.py +++ b/dspy/teleprompt/random_search.py @@ -1,6 +1,5 @@ import random -import dspy from dspy.evaluate.evaluate import Evaluate from dspy.teleprompt.teleprompt import Teleprompter @@ -55,10 +54,10 @@ def __init__( # self.max_bootstrapped_demos = self.max_num_traces self.max_labeled_demos = max_labeled_demos - dspy.logger.info( + print( "Going to sample between", self.min_num_samples, "and", self.max_num_samples, "traces per predictor.", ) - dspy.logger.info("Will attempt to train", self.num_candidate_sets, "candidate sets.") + print("Will attempt to bootstrap", self.num_candidate_sets, "candidate sets.") def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None, labeled_sample=True): self.trainset = trainset @@ -132,16 +131,16 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None score = 0 if program2._assert_failures > 0 else score ###################################################### - dspy.logger.info("Score:", score, "for set:", [len(predictor.demos) for predictor in program2.predictors()]) + print("Score:", score, "for set:", [len(predictor.demos) for predictor in program2.predictors()]) if len(scores) == 0 or score > max(scores): - dspy.logger.info("New best sscore:", score, "for seed", seed) + print("New best sscore:", score, "for seed", seed) best_program = program2 scores.append(score) - dspy.logger.info(f"Scores so far: {scores}") + print(f"Scores so far: {scores}") - dspy.logger.info("Best score:", max(scores)) + print("Best score:", max(scores)) score_data.append((score, subscores, seed, program2)) @@ -153,17 +152,17 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None transposed_subscores = zip(*[subscores for _, subscores, *_ in top_3_scores if subscores]) avg_of_max_per_entry = sum(max(entry) for entry in transposed_subscores) / len(top_3_scores[0][1]) - dspy.logger.info(f"Average of max per entry across top {k} scores: {avg_of_max_per_entry}") + print(f"Average of max per entry across top {k} scores: {avg_of_max_per_entry}") if self.stop_at_score is not None and score >= self.stop_at_score: - dspy.logger.info(f"Stopping early because score {score} is >= stop_at_score {self.stop_at_score}") + print(f"Stopping early because score {score} is >= stop_at_score {self.stop_at_score}") break # To best program, attach all program candidates in decreasing average score best_program.candidate_programs = score_data best_program.candidate_programs = sorted(best_program.candidate_programs, key=lambda x: x[0], reverse=True) - dspy.logger.info(f"{len(best_program.candidate_programs)} candidate programs found.") + print(f"{len(best_program.candidate_programs)} candidate programs found.") return best_program