diff --git a/dsp/modules/async_gpt3.py b/dsp/modules/async_gpt3.py
new file mode 100644
index 0000000000..2f432c2dba
--- /dev/null
+++ b/dsp/modules/async_gpt3.py
@@ -0,0 +1,94 @@
+import json
+from typing import Any, cast
+
+import backoff
+import openai
+import openai.error
+from openai.openai_object import OpenAIObject
+
+
+from dsp.modules.gpt3 import GPT3, backoff_hdlr
+
+
+class AsyncGPT3(GPT3):
+    """Wrapper around OpenAI's GPT API. Supports both the OpenAI and Azure APIs.
+
+    Args:
+        model (str, optional): OpenAI or Azure supported LLM model to use. Defaults to "text-davinci-002".
+        api_key (Optional[str], optional): API provider authentication token. Defaults to None.
+        api_provider (Literal["openai", "azure"], optional): The API provider to use. Defaults to "openai".
+        model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
+        **kwargs: Additional arguments to pass to the API provider.
+    """
+
+    async def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
+        raw_kwargs = kwargs
+
+        kwargs = {**self.kwargs, **kwargs}
+        if self.model_type == "chat":
+            # caching mechanism requires hashable kwargs
+            kwargs["messages"] = [{"role": "user", "content": prompt}]
+            kwargs = {"stringify_request": json.dumps(kwargs)}
+            response = await _a_gpt3_chat_request(**kwargs)
+
+        else:
+            kwargs["prompt"] = prompt
+            response = await _a_gpt3_completion_request(**kwargs)
+
+        self._add_to_history(prompt, response, kwargs, raw_kwargs)
+
+        return response
+
+    @backoff.on_exception(
+        backoff.expo,
+        (openai.error.RateLimitError, openai.error.ServiceUnavailableError),
+        max_time=1000,
+        on_backoff=backoff_hdlr,
+    )
+    async def request(self, prompt: str, **kwargs) -> OpenAIObject:
+        """Handles retrieval of GPT-3 completions while handling rate limiting and caching."""
+        if "model_type" in kwargs:
+            del kwargs["model_type"]
+
+        return await self.basic_request(prompt, **kwargs)
+
+    async def __call__(
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
+    ) -> list[dict[str, Any]]:
+        """Retrieves completions from GPT-3.
+
+        Args:
+            prompt (str): prompt to send to GPT-3
+            only_completed (bool, optional): return only completed responses and ignore completions cut off due to length. Defaults to True.
+            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.
+
+        Returns:
+            list[dict[str, Any]]: list of completion choices
+        """
+
+        assert only_completed, "for now"
+        assert return_sorted is False, "for now"
+
+        response = await self.request(prompt, **kwargs)
+        completions = self._get_completions_from_response(
+            response=response,
+            only_completed=only_completed,
+            return_sorted=return_sorted,
+            **kwargs,
+        )
+        return completions
+
+
+async def _a_gpt3_completion_request(**kwargs):
+    return await openai.Completion.acreate(**kwargs)
+
+
+async def _a_gpt3_chat_request(**kwargs) -> OpenAIObject:
+    if "stringify_request" in kwargs:
+        kwargs = json.loads(kwargs["stringify_request"])
+    res = await openai.ChatCompletion.acreate(**kwargs)
+    return cast(OpenAIObject, res)
diff --git a/dsp/modules/gpt3.py b/dsp/modules/gpt3.py
index f0c7b6679f..150c4c1328 100644
--- a/dsp/modules/gpt3.py
+++ b/dsp/modules/gpt3.py
@@ -42,7 +42,12 @@ def __init__(
         super().__init__(model)
         self.provider = "openai"
 
-        default_model_type = "chat" if ('gpt-3.5' in model or 'turbo' in model or 'gpt-4' in model) and ('instruct' not in model) else "text"
+        default_model_type = (
+            "chat"
+            if ("gpt-3.5" in model or "turbo" in model or "gpt-4" in model)
+            and ("instruct" not in model)
+            else "text"
+        )
         self.model_type = model_type if model_type else default_model_type
 
         if api_provider == "azure":
@@ -70,7 +75,7 @@ def __init__(
             "n": 1,
             **kwargs,
         }  # TODO: add kwargs above for
-        
+
         if api_provider != "azure":
             self.kwargs["model"] = model
         self.history: list[dict[str, Any]] = []
@@ -78,6 +83,17 @@ def __init__(
     def _openai_client():
         return openai
 
+    def _add_to_history(
+        self, prompt: str, response: OpenAIObject, kwargs: dict, raw_kwargs: dict
+    ):
+        history = {
+            "prompt": prompt,
+            "response": response,
+            "kwargs": kwargs,
+            "raw_kwargs": raw_kwargs,
+        }
+        self.history.append(history)
+
     def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
         raw_kwargs = kwargs
 
@@ -85,22 +101,14 @@ def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
         if self.model_type == "chat":
            # caching mechanism requires hashable kwargs
             kwargs["messages"] = [{"role": "user", "content": prompt}]
-            kwargs = {
-                "stringify_request": json.dumps(kwargs)
-            }
+            kwargs = {"stringify_request": json.dumps(kwargs)}
             response = cached_gpt3_turbo_request(**kwargs)
-        
+
         else:
             kwargs["prompt"] = prompt
             response = cached_gpt3_request(**kwargs)
 
-        history = {
-            "prompt": prompt,
-            "response": response,
-            "kwargs": kwargs,
-            "raw_kwargs": raw_kwargs,
-        }
-        self.history.append(history)
+        self._add_to_history(prompt, response, kwargs, raw_kwargs)
 
         return response
 
@@ -114,7 +122,7 @@ def request(self, prompt: str, **kwargs) -> OpenAIObject:
         """Handles retreival of GPT-3 completions whilst handling rate limiting and caching."""
         if "model_type" in kwargs:
             del kwargs["model_type"]
-        
+
         return self.basic_request(prompt, **kwargs)
 
     def _get_choice_text(self, choice: dict[str, Any]) -> str:
@@ -122,34 +130,13 @@ def _get_choice_text(self, choice: dict[str, Any]) -> str:
             return choice["message"]["content"]
         return choice["text"]
 
-    def __call__(
+    def _get_completions_from_response(
         self,
-        prompt: str,
+        response: OpenAIObject,
         only_completed: bool = True,
         return_sorted: bool = False,
         **kwargs,
     ) -> list[dict[str, Any]]:
-        """Retrieves completions from GPT-3.
-
-        Args:
-            prompt (str): prompt to send to GPT-3
-            only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
-            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.
-
-        Returns:
-            list[dict[str, Any]]: list of completion choices
-        """
-
-        assert only_completed, "for now"
-        assert return_sorted is False, "for now"
-
-        # if kwargs.get("n", 1) > 1:
-        #     if self.model_type == "chat":
-        #         kwargs = {**kwargs}
-        #     else:
-        #         kwargs = {**kwargs, "logprobs": 5}
-
-        response = self.request(prompt, **kwargs)
         choices = response["choices"]
         completed_choices = [c for c in choices if c["finish_reason"] != "length"]
 
@@ -180,6 +167,42 @@ def __call__(
 
         return completions
 
+    def __call__(
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
+    ) -> list[dict[str, Any]]:
+        """Retrieves completions from GPT-3.
+
+        Args:
+            prompt (str): prompt to send to GPT-3
+            only_completed (bool, optional): return only completed responses and ignore completions cut off due to length. Defaults to True.
+            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.
+
+        Returns:
+            list[dict[str, Any]]: list of completion choices
+        """
+
+        assert only_completed, "for now"
+        assert return_sorted is False, "for now"
+
+        # if kwargs.get("n", 1) > 1:
+        #     if self.model_type == "chat":
+        #         kwargs = {**kwargs}
+        #     else:
+        #         kwargs = {**kwargs, "logprobs": 5}
+
+        response = self.request(prompt, **kwargs)
+        completions = self._get_completions_from_response(
+            response=response,
+            only_completed=only_completed,
+            return_sorted=return_sorted,
+            **kwargs,
+        )
+        return completions
+
 
 @CacheMemory.cache
 def cached_gpt3_request_v2(**kwargs):
diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py
index 73dbee6438..b675be5f14 100644
--- a/dspy/primitives/program.py
+++ b/dspy/primitives/program.py
@@ -7,7 +7,7 @@ class ProgramMeta(type):
     pass
     # def __call__(cls, *args, **kwargs):
     #     obj = super(ProgramMeta, cls).__call__(*args, **kwargs)
-        
+
     #     if issubclass(cls, Program) and not getattr(obj, "_program_init_called", False):
     #         obj._base_init()
     #         obj._program_init_called = True
@@ -23,41 +23,32 @@ def __init__(self):
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
-    
+
     def named_predictors(self):
         from dspy.predict.predict import Predict
-        
+
         named_parameters = self.named_parameters()
-        return [(name, param) for name, param in named_parameters if isinstance(param, Predict)]
+        return [
+            (name, param)
+            for name, param in named_parameters
+            if isinstance(param, Predict)
+        ]
 
     def predictors(self):
         return [param for _, param in self.named_predictors()]
-    
+
     def __repr__(self):
         s = []
 
         for name, param in self.named_predictors():
             s.append(f"{name} = {param}")
-        
-        return '\n'.join(s)
-    
-    # def __deepcopy__(self, memo):
-    #     # memo is a dict of id's to copies already made during the current call
-    #     # Check if the object is already copied
-    #     if id(self) in memo:
-    #         return memo[id(self)]
-
-    #     print(f"Deep copying {self.__class__.__name__}...")
-    #     new_copy = copy.copy(self)
-    #     memo[id(self)] = new_copy
+        return "\n".join(s)
 
-    #     for k, v in self.__dict__.items():
-    #         print(f"Copying attribute {k} of type {type(v)}...")
-    #         setattr(new_copy, k, copy.deepcopy(v, memo))
-    #     print("Done")
-    #     return new_copy
+class AsyncModule(Module):
+    async def __call__(self, *args, **kwargs):
+        return await self.forward(*args, **kwargs)
 
-Program = Module
\ No newline at end of file
+Program = Module
diff --git a/dspy/teleprompt/async_bootstrap.py b/dspy/teleprompt/async_bootstrap.py
new file mode 100644
index 0000000000..a712bcc8f3
--- /dev/null
+++ b/dspy/teleprompt/async_bootstrap.py
@@ -0,0 +1,87 @@
+import dsp
+import tqdm
+import random
+
+from dspy.primitives import Example
+from dspy.teleprompt.bootstrap import BootstrapFewShot
+
+
+class AsyncBootstrapFewShot(BootstrapFewShot):
+    async def compile(self, student, *, teacher=None, trainset, valset=None):
+        self.prepare_for_bootstrap(
+            student, teacher=teacher, trainset=trainset, valset=valset
+        )
+        await self._bootstrap()
+
+        self.student = self._train()
+        self.student._compiled = True
+
+        return self.student
+
+    async def _bootstrap(self, *, max_bootstraps=None):
+        max_bootstraps = max_bootstraps or self.max_bootstrapped_demos
+
+        bootstrapped = {}
+        self.name2traces = {name: [] for name in self.name2predictor}
+
+        for round_idx in range(self.max_rounds):
+            for example_idx, example in enumerate(tqdm.tqdm(self.trainset)):
+                if len(bootstrapped) >= max_bootstraps:
+                    break
+
+                if example_idx not in bootstrapped:
+                    success = await self._bootstrap_one_example(example, round_idx)
+
+                    if success:
+                        bootstrapped[example_idx] = True
+
+        print(
+            f"Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}."
+        )
+
+        # Unbootstrapped training examples
+
+        self.validation = [
+            x for idx, x in enumerate(self.trainset) if idx not in bootstrapped
+        ]
+        random.Random(0).shuffle(self.validation)
+
+        self.validation = self.valset or self.validation
+
+        # NOTE: Can't yet use evaluate because we need to trace *per example*
+        # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
+        # score = evaluate(self.metric, display_table=False, display_progress=True)
+
+    async def _bootstrap_one_example(self, example, round_idx=0):
+        name2traces = self.name2traces
+        teacher = self.teacher  # .deepcopy()
+        predictor_cache = {}
+
+        try:
+            with dsp.settings.context(trace=[], **self.teacher_settings):
+                new_settings = self._make_new_settings(round_idx)
+
+                with dsp.settings.context(**new_settings):
+                    self._cache_and_update_predictor_demos(
+                        teacher, example, predictor_cache
+                    )
+
+                    prediction = await teacher(**example.inputs())
+                    trace = dsp.settings.trace
+                    self._restore_predictor_demos_from_cache(teacher, predictor_cache)
+
+                    success = (self.metric is None) or self.metric(
+                        example, prediction, trace
+                    )
+                    # print(success, example, prediction)
+        except Exception as e:
+            success = False
+            # FIXME: remove the reliance on uuid here so the error is printed
+            print(
+                f"Failed to run or to evaluate example {example} with {self.metric} due to {e}."
+            )
+
+        if success:
+            self._make_successful_demos(trace, example, name2traces)
+
+        return success
diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py
index 14588592d9..2456aac313 100644
--- a/dspy/teleprompt/bootstrap.py
+++ b/dspy/teleprompt/bootstrap.py
@@ -30,7 +30,14 @@ class BootstrapFewShot(Teleprompter):
-    def __init__(self, metric=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1):
+    def __init__(
+        self,
+        metric=None,
+        teacher_settings={},
+        max_bootstrapped_demos=4,
+        max_labeled_demos=16,
+        max_rounds=1,
+    ):
         self.metric = metric
         self.teacher_settings = teacher_settings
 
@@ -38,41 +45,60 @@ def __init__(self, metric=None, teacher_settings={}, max_bootstrapped_demos=4, m
         self.max_labeled_demos = max_labeled_demos
         self.max_rounds = max_rounds
 
-    def compile(self, student, *, teacher=None, trainset, valset=None):
+    def prepare_for_bootstrap(self, student, *, teacher=None, trainset, valset=None):
         self.trainset = trainset
         self.valset = valset
 
         self._prepare_student_and_teacher(student, teacher)
         self._prepare_predictor_mappings()
+
+    def compile(self, student, *, teacher=None, trainset, valset=None):
+        self.prepare_for_bootstrap(
+            student, teacher=teacher, trainset=trainset, valset=valset
+        )
         self._bootstrap()
 
         self.student = self._train()
         self.student._compiled = True
 
         return self.student
-    
+
     def _prepare_student_and_teacher(self, student, teacher):
         self.student = student.reset_copy()
-        self.teacher = teacher.deepcopy() if teacher is not None else student.reset_copy()
+        self.teacher = (
+            teacher.deepcopy() if teacher is not None else student.reset_copy()
+        )
 
         assert self.student._compiled is False, "Student must be uncompiled."
 
         if self.max_labeled_demos and self.teacher._compiled is False:
             teleprompter = LabeledFewShot(k=self.max_labeled_demos)
-            self.teacher = teleprompter.compile(self.teacher.reset_copy(), trainset=self.trainset)
+            self.teacher = teleprompter.compile(
+                self.teacher.reset_copy(), trainset=self.trainset
+            )
 
     def _prepare_predictor_mappings(self):
         name2predictor, predictor2name = {}, {}
         student, teacher = self.student, self.teacher
 
-        assert len(student.predictors()) == len(teacher.predictors()), "Student and teacher must have the same number of predictors."
-
-        for (name1, predictor1), (name2, predictor2) in zip(student.named_predictors(), teacher.named_predictors()):
-            assert name1 == name2, "Student and teacher must have the same program structure."
-            assert predictor1.signature == predictor2.signature, f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
-            assert id(predictor1) != id(predictor2), "Student and teacher must be different objects."
-
-            name2predictor[name1] = None # dict(student=predictor1, teacher=predictor2)
+        assert len(student.predictors()) == len(
+            teacher.predictors()
+        ), "Student and teacher must have the same number of predictors."
+
+        for (name1, predictor1), (name2, predictor2) in zip(
+            student.named_predictors(), teacher.named_predictors()
+        ):
+            assert (
+                name1 == name2
+            ), "Student and teacher must have the same program structure."
+            assert (
+                predictor1.signature == predictor2.signature
+            ), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
+            assert id(predictor1) != id(
+                predictor2
+            ), "Student and teacher must be different objects."
+
+            name2predictor[name1] = None  # dict(student=predictor1, teacher=predictor2)
             predictor2name[id(predictor1)] = name1
             predictor2name[id(predictor2)] = name2
@@ -95,12 +121,16 @@ def _bootstrap(self, *, max_bootsraps=None):
                 if success:
                     bootstrapped[example_idx] = True
-        
-        print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.')
-        
+
+        print(
+            f"Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}."
+        )
+
         # Unbootstrapped training examples
 
-        self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
+        self.validation = [
+            x for idx, x in enumerate(self.trainset) if idx not in bootstrapped
+        ]
         random.Random(0).shuffle(self.validation)
 
         self.validation = self.valset or self.validation
 
@@ -108,59 +138,82 @@ def _bootstrap(self, *, max_bootsraps=None):
         # NOTE: Can't yet use evaluate because we need to trace *per example*
         # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
         # score = evaluate(self.metric, display_table=False, display_progress=True)
-        
+
+    def _make_successful_demos(self, trace, example, name2traces):
+        for step in trace:
+            predictor, inputs, outputs = step
+
+            if "dspy_uuid" in example:
+                demo = Example(
+                    augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs
+                )
+            else:
+                # TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
+                demo = Example(augmented=True, **inputs, **outputs)
+
+            try:
+                predictor_name = self.predictor2name[id(predictor)]
+            except KeyError as e:
+                continue  # FIXME: !
+
+                # TODO: Look closer into this. It's a bit tricky to reproduce.
+                print(f"Failed to find predictor {predictor} in {self.predictor2name}.")
+                print(
+                    "Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells."
+                )
+                print("Try restarting the notebook, or open an issue.")
+                raise KeyError(
+                    f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}."
+                ) from e
+
+            name2traces[predictor_name].append(demo)
+
+    def _make_new_settings(self, round_idx) -> dict:
+        lm = dsp.settings.lm
+        lm = lm.copy(temperature=0.7 + 0.001 * round_idx) if round_idx > 0 else lm
+        return dict(lm=lm) if round_idx > 0 else {}
+
+    def _cache_and_update_predictor_demos(self, teacher, example, predictor_cache):
+        for name, predictor in teacher.named_predictors():
+            predictor_cache[name] = predictor.demos
+            predictor.demos = [x for x in predictor.demos if x != example]
+
+    def _restore_predictor_demos_from_cache(self, teacher, predictor_cache):
+        for name, predictor in teacher.named_predictors():
+            predictor.demos = predictor_cache[name]
+
     def _bootstrap_one_example(self, example, round_idx=0):
         name2traces = self.name2traces
-        teacher = self.teacher #.deepcopy()
+        teacher = self.teacher  # .deepcopy()
         predictor_cache = {}
 
         try:
             with dsp.settings.context(trace=[], **self.teacher_settings):
-                lm = dsp.settings.lm
-                lm = lm.copy(temperature=0.7 + 0.001 * round_idx) if round_idx > 0 else lm
-                new_settings = dict(lm=lm) if round_idx > 0 else {}
+                new_settings = self._make_new_settings(round_idx)
 
                 with dsp.settings.context(**new_settings):
-                    for name, predictor in teacher.named_predictors():
-                        predictor_cache[name] = predictor.demos
-                        predictor.demos = [x for x in predictor.demos if x != example]
+                    self._cache_and_update_predictor_demos(
+                        teacher, example, predictor_cache
+                    )
 
                     prediction = teacher(**example.inputs())
                     trace = dsp.settings.trace
+                    self._restore_predictor_demos_from_cache(teacher, predictor_cache)
 
-                    for name, predictor in teacher.named_predictors():
-                        predictor.demos = predictor_cache[name]
-
-                    success = (self.metric is None) or self.metric(example, prediction, trace)
+                    success = (self.metric is None) or self.metric(
+                        example, prediction, trace
+                    )
                     # print(success, example, prediction)
         except Exception as e:
             success = False
             # FIXME: remove the reliance on uuid here so the error is printed
-            print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')
-        
+            print(
+                f"Failed to run or to evaluate example {example} with {self.metric} due to {e}."
+            )
+
         if success:
-            for step in trace:
-                predictor, inputs, outputs = step
-
-                if 'dspy_uuid' in example:
-                    demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
-                else:
-                    # TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
-                    demo = Example(augmented=True, **inputs, **outputs)
-
-                try:
-                    predictor_name = self.predictor2name[id(predictor)]
-                except KeyError as e:
-                    continue # FIXME: !
-
-                    # TODO: Look closer into this. It's a bit tricky to reproduce.
-                    print(f'Failed to find predictor {predictor} in {self.predictor2name}.')
-                    print('Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.')
-                    print('Try restarting the notebook, or open an issue.')
-                    raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.') from e
-
-                name2traces[predictor_name].append(demo)
-        
+            self._make_successful_demos(trace, example, name2traces)
+
         return success
 
     def _train(self):
@@ -168,14 +221,17 @@ def _train(self):
         raw_demos = self.validation
 
         for name, predictor in self.student.named_predictors():
-            augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]
-
-            sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
+            augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]
+
+            sample_size = min(
+                self.max_labeled_demos - len(augmented_demos), len(raw_demos)
+            )
             sample_size = max(0, sample_size)
 
             raw_demos = rng.sample(raw_demos, sample_size)
-            
+
             import dspy
+
             if dspy.settings.release >= 20230928:
                 predictor.demos = raw_demos + augmented_demos
             else: