94 changes: 94 additions & 0 deletions dsp/modules/async_gpt3.py
@@ -0,0 +1,94 @@
import json
from typing import Any, cast

import backoff
import openai
import openai.error
from openai.openai_object import OpenAIObject


from dsp.modules.gpt3 import GPT3, backoff_hdlr


class AsyncGPT3(GPT3):
"""Wrapper around OpenAI's GPT API. Supports both the OpenAI and Azure APIs.

Args:
model (str, optional): OpenAI or Azure supported LLM model to use. Defaults to "text-davinci-002".
api_key (Optional[str], optional): API provider authentication token. Defaults to None.
api_provider (Literal["openai", "azure"], optional): The API provider to use. Defaults to "openai".
model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
**kwargs: Additional arguments to pass to the API provider.
"""

async def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
raw_kwargs = kwargs

kwargs = {**self.kwargs, **kwargs}
if self.model_type == "chat":
# caching mechanism requires hashable kwargs
kwargs["messages"] = [{"role": "user", "content": prompt}]
kwargs = {"stringify_request": json.dumps(kwargs)}
response = await _a_gpt3_chat_request(**kwargs)

else:
kwargs["prompt"] = prompt
response = await _a_gpt3_completion_request(**kwargs)

self._add_to_history(prompt, response, kwargs, raw_kwargs)

return response

@backoff.on_exception(
backoff.expo,
(openai.error.RateLimitError, openai.error.ServiceUnavailableError),
max_time=1000,
on_backoff=backoff_hdlr,
)
async def request(self, prompt: str, **kwargs) -> OpenAIObject:
"""Handles retreival of GPT-3 completions whilst handling rate limiting and caching."""
if "model_type" in kwargs:
del kwargs["model_type"]

return await self.basic_request(prompt, **kwargs)

async def __call__(
self,
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
) -> list[dict[str, Any]]:
"""Retrieves completions from GPT-3.

Args:
prompt (str): prompt to send to GPT-3
only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

Returns:
list[dict[str, Any]]: list of completion choices
"""

assert only_completed, "for now"
assert return_sorted is False, "for now"

response = await self.request(prompt, **kwargs)
completions = self._get_completions_from_response(
response=response,
only_completed=only_completed,
return_sorted=return_sorted,
**kwargs,
)
return completions


async def _a_gpt3_completion_request(**kwargs):
return await openai.Completion.acreate(**kwargs)


async def _a_gpt3_chat_request(**kwargs) -> OpenAIObject:
if "stringify_request" in kwargs:
kwargs = json.loads(kwargs["stringify_request"])
res = await openai.ChatCompletion.acreate(**kwargs)
return cast(OpenAIObject, res)
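
For context, a minimal usage sketch of the new AsyncGPT3 wrapper (editorial, not part of the diff): the model name and prompt are placeholders, and it assumes an OpenAI API key is already configured the same way the synchronous GPT3 class expects.

import asyncio

from dsp.modules.async_gpt3 import AsyncGPT3

async def main():
    # Placeholder model and prompt; constructor arguments are inherited from GPT3.
    lm = AsyncGPT3(model="text-davinci-002")
    completions = await lm("Say hello in French.")
    print(completions)

asyncio.run(main())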
97 changes: 60 additions & 37 deletions dsp/modules/gpt3.py
@@ -42,7 +42,12 @@ def __init__(
super().__init__(model)
self.provider = "openai"

default_model_type = "chat" if ('gpt-3.5' in model or 'turbo' in model or 'gpt-4' in model) and ('instruct' not in model) else "text"
default_model_type = (
"chat"
if ("gpt-3.5" in model or "turbo" in model or "gpt-4" in model)
and ("instruct" not in model)
else "text"
)
self.model_type = model_type if model_type else default_model_type

if api_provider == "azure":
@@ -70,37 +75,40 @@ def __init__(
"n": 1,
**kwargs,
} # TODO: add kwargs above for </s>

if api_provider != "azure":
self.kwargs["model"] = model
self.history: list[dict[str, Any]] = []

def _openai_client():
return openai

def _add_to_history(
self, prompt: str, response: OpenAIObject, kwargs: dict, raw_kwargs: dict
):
history = {
"prompt": prompt,
"response": response,
"kwargs": kwargs,
"raw_kwargs": raw_kwargs,
}
self.history.append(history)

def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
raw_kwargs = kwargs

kwargs = {**self.kwargs, **kwargs}
if self.model_type == "chat":
# caching mechanism requires hashable kwargs
kwargs["messages"] = [{"role": "user", "content": prompt}]
kwargs = {
"stringify_request": json.dumps(kwargs)
}
kwargs = {"stringify_request": json.dumps(kwargs)}
response = cached_gpt3_turbo_request(**kwargs)

else:
kwargs["prompt"] = prompt
response = cached_gpt3_request(**kwargs)

history = {
"prompt": prompt,
"response": response,
"kwargs": kwargs,
"raw_kwargs": raw_kwargs,
}
self.history.append(history)
self._add_to_history(prompt, response, kwargs, raw_kwargs)

return response

@@ -114,42 +122,21 @@ def request(self, prompt: str, **kwargs) -> OpenAIObject:
"""Handles retreival of GPT-3 completions whilst handling rate limiting and caching."""
if "model_type" in kwargs:
del kwargs["model_type"]

return self.basic_request(prompt, **kwargs)

def _get_choice_text(self, choice: dict[str, Any]) -> str:
if self.model_type == "chat":
return choice["message"]["content"]
return choice["text"]

def __call__(
def _get_completions_from_response(
self,
prompt: str,
response: OpenAIObject,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
) -> list[dict[str, Any]]:
"""Retrieves completions from GPT-3.

Args:
prompt (str): prompt to send to GPT-3
only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

Returns:
list[dict[str, Any]]: list of completion choices
"""

assert only_completed, "for now"
assert return_sorted is False, "for now"

# if kwargs.get("n", 1) > 1:
# if self.model_type == "chat":
# kwargs = {**kwargs}
# else:
# kwargs = {**kwargs, "logprobs": 5}

response = self.request(prompt, **kwargs)
choices = response["choices"]

completed_choices = [c for c in choices if c["finish_reason"] != "length"]
@@ -180,6 +167,42 @@ def __call__(

return completions

def __call__(
self,
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
) -> list[dict[str, Any]]:
"""Retrieves completions from GPT-3.

Args:
prompt (str): prompt to send to GPT-3
only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

Returns:
list[dict[str, Any]]: list of completion choices
"""

assert only_completed, "for now"
assert return_sorted is False, "for now"

# if kwargs.get("n", 1) > 1:
# if self.model_type == "chat":
# kwargs = {**kwargs}
# else:
# kwargs = {**kwargs, "logprobs": 5}

response = self.request(prompt, **kwargs)
completions = self._get_completions_from_response(
response=response,
only_completed=only_completed,
return_sorted=return_sorted,
**kwargs,
)
return completions


@CacheMemory.cache
def cached_gpt3_request_v2(**kwargs):
37 changes: 14 additions & 23 deletions dspy/primitives/program.py
@@ -7,7 +7,7 @@ class ProgramMeta(type):
pass
# def __call__(cls, *args, **kwargs):
# obj = super(ProgramMeta, cls).__call__(*args, **kwargs)

# if issubclass(cls, Program) and not getattr(obj, "_program_init_called", False):
# obj._base_init()
# obj._program_init_called = True
@@ -23,41 +23,32 @@ def __init__(self):

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def named_predictors(self):
from dspy.predict.predict import Predict

named_parameters = self.named_parameters()
return [(name, param) for name, param in named_parameters if isinstance(param, Predict)]
return [
(name, param)
for name, param in named_parameters
if isinstance(param, Predict)
]

def predictors(self):
return [param for _, param in self.named_predictors()]

def __repr__(self):
s = []

for name, param in self.named_predictors():
s.append(f"{name} = {param}")

return '\n'.join(s)

# def __deepcopy__(self, memo):
# # memo is a dict of id's to copies already made during the current call
# # Check if the object is already copied
# if id(self) in memo:
# return memo[id(self)]

# print(f"Deep copying {self.__class__.__name__}...")

# new_copy = copy.copy(self)
# memo[id(self)] = new_copy
return "\n".join(s)

# for k, v in self.__dict__.items():
# print(f"Copying attribute {k} of type {type(v)}...")
# setattr(new_copy, k, copy.deepcopy(v, memo))
# print("Done")

# return new_copy
class AsyncModule(Module):
async def __call__(self, *args, **kwargs):
return await self.forward(*args, **kwargs)
Collaborator

Will this actually work and become truly async? The call to the LLM is still completely synchronous/blocking.

Author

Thanks for engaging! It depends on how you define forward(). The goal is just to allow you to do async ops, which could be LM-related or not (e.g. async db reads), in the module. For example,

class RAG(dspy.AsyncModule):
    def __init__(self):
        super().__init__()
    
    async def forward(self, question):
        search_entity_batches = await asyncio.gather(*[search_fn_1, search_fn_2, ...])
        # or could be async ranking, async ensemble answer generation, etc.
        ....

You're right in that it doesn't solve async for the current Predict and Retrieve abstractions, which use a sync LLM completion api call. Those would require a larger refactor to make async, but I think this is a reasonable escape hatch in the meantime.
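
For reference, a slightly more concrete version of that sketch (editorial; search_wikipedia and search_vector_db are hypothetical coroutines standing in for real async retrieval calls, not part of this PR, and AsyncModule is assumed to be exported at the dspy top level as in the comment above):

import asyncio
import dspy

async def search_wikipedia(question: str) -> list[str]:
    await asyncio.sleep(0)  # stand-in for a real async HTTP call
    return [f"wiki passage about {question}"]

async def search_vector_db(question: str) -> list[str]:
    await asyncio.sleep(0)  # stand-in for a real async vector-store query
    return [f"vector hit for {question}"]

class RAG(dspy.AsyncModule):
    async def forward(self, question):
        # asyncio.gather takes awaitables, so each coroutine is called with the question
        batches = await asyncio.gather(
            search_wikipedia(question), search_vector_db(question)
        )
        return [passage for batch in batches for passage in batch]

passages = asyncio.run(RAG()(question="What is DSPy?"))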

Collaborator

This is very reasonable. I can merge a very small change like this for sure but right now it's affecting a lot of lines of code



Program = Module