diff --git a/docs/api/language_model_clients/PremAI.md b/docs/api/language_model_clients/PremAI.md
new file mode 100644
index 0000000000..1bc7d90474
--- /dev/null
+++ b/docs/api/language_model_clients/PremAI.md
@@ -0,0 +1,70 @@
+---
+sidebar_position: 5
+---
+
+# dsp.PremAI
+
+[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.
+
+### Prerequisites
+
+Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform, create your first project, and grab your API key.
+
+### Usage
+
+Please make sure you have the `premai` Python SDK installed. If not, you can install it with:
+
+```bash
+pip install -U premai
+```
+
+Here is a quick example of how to use the `premai` Python SDK with DSPy:
+
+```python
+from dspy import PremAI
+
+llm = PremAI(model='mistral-tiny', project_id=123, api_key="your-premai-api-key")
+print(llm("what is a large language model"))
+```
+
+> Please note: Project ID 123 is just an example. You can find your project ID in the PremAI platform, under the project you created.
+
+### Constructor
+
+The constructor initializes the base class `LM` and verifies the `api_key` provided as an argument or defined through the `PREMAI_API_KEY` environment variable.
+
+```python
+class PremAI(LM):
+    def __init__(
+        self,
+        project_id: int,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        session_id: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+```
+
+**Parameters:**
+
+- `model` (_Optional[str]_, _optional_): Name of a model supported by PremAI, e.g. `mistral-tiny`. Defaults to the model selected in the [project launchpad](https://docs.premai.io/get-started/launchpad).
+- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) that contains the model of choice.
+- `api_key` (_Optional[str]_, _optional_): API key from PremAI. Defaults to None, in which case it is read from the `PREMAI_API_KEY` environment variable.
+- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
+- `**kwargs`: Additional language model arguments that will be passed to the API provider.
+
+### Methods
+
+#### `__call__(self, prompt: str, **kwargs) -> str`
+
+Retrieves completions from PremAI by calling `request`.
+
+Internally, the method handles the specifics of preparing the request prompt and the corresponding payload to obtain the response. The generated completion is returned as a single string.
+
+**Parameters:**
+
+- `prompt` (_str_): Prompt to send to PremAI.
+- `**kwargs`: Additional keyword arguments for the completion request, e.g. `temperature`, `max_tokens`, etc. You can find all the supported kwargs [here](https://docs.premai.io/get-started/sdk#optional-parameters).
diff --git a/docs/docs/building-blocks/1-language_models.md b/docs/docs/building-blocks/1-language_models.md
index ddf5f815e0..25fb134c47 100644
--- a/docs/docs/building-blocks/1-language_models.md
+++ b/docs/docs/building-blocks/1-language_models.md
@@ -137,6 +137,7 @@ lm = dspy.{provider_listed_below}(model="your model", model_request_kwargs="..."
 
 4. `dspy.Together` for various hosted open-source models.
 
+5. `dspy.PremAI` for hosted open-source and closed-source models, as sketched below.
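+
+For example, a minimal PremAI setup might look like the sketch below (the `project_id` and API key are placeholders; use the values from your own PremAI dashboard):
+
+```python
+import dspy
+
+lm = dspy.PremAI(model="mistral-tiny", project_id=123, api_key="your-premai-api-key")
+dspy.configure(lm=lm)
+```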
 
 ### Local LMs.
 
@@ -173,4 +174,4 @@ model = 'dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1'
 model_path = 'dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so'
 
 llama = dspy.ChatModuleClient(model=model, model_path=model_path)
-```
+```
\ No newline at end of file
diff --git a/docs/docs/deep-dive/language_model_clients/remote_models/PremAI.mdx b/docs/docs/deep-dive/language_model_clients/remote_models/PremAI.mdx
new file mode 100644
index 0000000000..f41bae1394
--- /dev/null
+++ b/docs/docs/deep-dive/language_model_clients/remote_models/PremAI.mdx
@@ -0,0 +1,70 @@
+## PremAI
+
+[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.
+
+### Prerequisites
+
+Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform, create your first project, and grab your API key.
+
+### Setting up the PremAI Client
+
+The constructor initializes the base class `LM` to support prompting requests to PremAI-hosted models. It accepts the following parameters:
+
+- `model` (_Optional[str]_, _optional_): Name of a model supported by PremAI, e.g. `mistral-tiny`. Defaults to the model selected in the [project launchpad](https://docs.premai.io/get-started/launchpad).
+- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) that contains the model of choice.
+- `api_key` (_Optional[str]_, _optional_): API key from PremAI. Defaults to None, in which case it is read from the `PREMAI_API_KEY` environment variable.
+- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
+- `**kwargs`: Additional language model arguments that will be passed to the API provider.
+
+The PremAI constructor signature:
+
+```python
+class PremAI(LM):
+    def __init__(
+        self,
+        project_id: int,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        session_id: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+```
+
+### Under the Hood
+
+#### `__call__(self, prompt: str, **kwargs) -> str`
+
+**Parameters:**
+
+- `prompt` (_str_): Prompt to send to PremAI.
+- `**kwargs`: Additional keyword arguments for the completion request.
+
+**Returns:**
+
+- `str`: Completion string from the chosen LLM provider.
+
+Internally, the method handles the specifics of preparing the request prompt and the corresponding payload to obtain the response.
+
+### Using the PremAI client
+
+```python
+premai_client = dspy.PremAI(project_id=1111)
+```
+
+Please note that this is a dummy `project_id`. You need to replace it with the project ID you want to use with DSPy.
+
+1) Use the client with DSPy modules:
+
+```python
+dspy.configure(lm=premai_client)
+
+# Example DSPy CoT QA program
+qa = dspy.ChainOfThought('question -> answer')
+
+response = qa(question="What is the capital of France?")
+print(response.answer)
+```
+
+2) Generate responses using the client directly:
+
+```python
+response = premai_client(prompt='What is the capital of France?')
+print(response)
+```
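+
+Optional generation parameters can be forwarded as keyword arguments in either style. A hedged sketch (the parameter values are illustrative; the full list of supported kwargs is in the [PremAI SDK docs](https://docs.premai.io/get-started/sdk#optional-parameters)):
+
+```python
+response = premai_client(
+    prompt="What is the capital of France?",
+    temperature=0.2,  # illustrative value, forwarded to the completion request
+    max_tokens=100,   # illustrative value, forwarded to the completion request
+)
+print(response)
+```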
\ No newline at end of file
diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py
index d02c59cbd2..a31c72911f 100644
--- a/dsp/modules/__init__.py
+++ b/dsp/modules/__init__.py
@@ -20,6 +20,7 @@
 from .hf_client import Anyscale, HFClientTGI, Together
 from .mistral import *
 from .ollama import *
+from .premai import PremAI
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py
index f56df569ea..8bed65aa36 100644
--- a/dsp/modules/lm.py
+++ b/dsp/modules/lm.py
@@ -52,10 +52,15 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             or provider == "groq"
             or provider == "Bedrock"
             or provider == "Sagemaker"
+            or provider == "premai"
         ):
             printed.append((prompt, x["response"]))
         elif provider == "anthropic":
-            blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
+            blocks = [
+                {"text": block.text}
+                for block in x["response"].content
+                if block.type == "text"
+            ]
             printed.append((prompt, blocks))
         elif provider == "cohere":
             printed.append((prompt, x["response"].text))
@@ -85,7 +90,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
         if provider == "cohere" or provider == "Bedrock" or provider == "Sagemaker":
             text = choices
         elif provider == "openai" or provider == "ollama":
-            text = ' ' + self._get_choice_text(choices[0]).strip()
+            text = " " + self._get_choice_text(choices[0]).strip()
         elif provider == "clarifai" or provider == "claude":
             text = choices
         elif provider == "groq":
@@ -96,14 +101,16 @@
             text = choices[0].message.content
         elif provider == "cloudflare":
             text = choices[0]
-        elif provider == "ibm":
+        elif provider == "ibm" or provider == "premai":
             text = choices
         else:
             text = choices[0]["text"]
         printing_value += self.print_green(text, end="")
 
         if len(choices) > 1 and isinstance(choices, list):
-            printing_value += self.print_red(f" \t (and {len(choices)-1} other completions)", end="")
+            printing_value += self.print_red(
+                f" \t (and {len(choices)-1} other completions)", end="",
+            )
 
         printing_value += "\n\n\n"
diff --git a/dsp/modules/premai.py b/dsp/modules/premai.py
new file mode 100644
index 0000000000..3c9589cae4
--- /dev/null
+++ b/dsp/modules/premai.py
@@ -0,0 +1,173 @@
+import os
+from typing import Any, Optional
+
+import backoff
+
+from dsp.modules.lm import LM
+
+try:
+    import premai
+
+    premai_api_error = premai.errors.UnexpectedStatus
+except (ImportError, AttributeError):
+    premai_api_error = Exception
+
+
+def backoff_hdlr(details) -> None:
+    """Handler for the backoff package.
+
+    See more at: https://pypi.org/project/backoff/
+    """
+    print(
+        "Backing off {wait:0.1f} seconds after {tries} tries calling function {target} with kwargs {kwargs}".format(
+            **details,
+        ),
+    )
+
+
+def giveup_hdlr(details) -> bool:
+    """Wrapper function that decides when to give up on retry."""
+    if "rate limits" in details.message:
+        return False
+    return True
+
+
+def get_premai_api_key(api_key: Optional[str] = None) -> str:
+    """Retrieve the PremAI API key from a passed argument or environment variable."""
+    api_key = api_key or os.environ.get("PREMAI_API_KEY")
+    if api_key is None:
+        raise RuntimeError(
+            "No API key found. See the quick start guide at https://docs.premai.io/introduction to get your API key.",
+        )
+    return api_key
+
+
+class PremAI(LM):
+    """Wrapper around Prem AI's API."""
+
+    def __init__(
+        self,
+        project_id: int,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        session_id: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        """Parameters
+
+        project_id: int
+            The project ID in which the experiments or deployments are carried out. You can find all your projects here: https://app.premai.io/projects/
+        model: Optional[str]
+            The name of the model deployed on the launchpad. When None, it defaults to "default".
+        api_key: Optional[str]
+            Prem AI API key, to connect with the API. If not provided, it is read from the env var
+            PREMAI_API_KEY.
+        session_id: Optional[int]
+            The ID of the session to use. It helps to track the chat history.
+        **kwargs: dict
+            Additional arguments to pass to the API provider.
+        """
+        model = "default" if model is None else model
+        super().__init__(model)
+        if premai_api_error == Exception:
+            raise ImportError(
+                "Not loading Prem AI because it is not installed. Install it with `pip install premai`.",
+            )
+
+        self.project_id = project_id
+        self.session_id = session_id
+
+        api_key = get_premai_api_key(api_key=api_key)
+        self.client = premai.Prem(api_key=api_key)
+        self.provider = "premai"
+        self.history: list[dict[str, Any]] = []
+
+        self.kwargs = {
+            "temperature": 0.17,
+            "max_tokens": 150,
+            **kwargs,
+        }
+        if session_id is not None:
+            self.kwargs["session_id"] = session_id
+
+        # Note: changing the model once it is deployed from the launchpad
+        # is not recommended.
+        if model != "default":
+            self.kwargs["model"] = model
+
+    def _get_all_kwargs(self, **kwargs) -> dict:
+        other_kwargs = {
+            "seed": None,
+            "logit_bias": None,
+            "tools": None,
+            "system_prompt": None,
+        }
+        all_kwargs = {
+            **self.kwargs,
+            **other_kwargs,
+            **kwargs,
+        }
+
+        _keys_that_cannot_be_none = [
+            "system_prompt",
+            "frequency_penalty",
+            "presence_penalty",
+            "tools",
+        ]
+
+        for key in _keys_that_cannot_be_none:
+            if all_kwargs.get(key) is None:
+                all_kwargs.pop(key, None)
+        return all_kwargs
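+
+    # Illustrative sketch of the merge above (0.5 is a hypothetical caller
+    # value, not an API default): with the constructor defaults
+    # self.kwargs == {"temperature": 0.17, "max_tokens": 150}, a call like
+    # _get_all_kwargs(temperature=0.5) returns
+    #     {"temperature": 0.5, "max_tokens": 150, "seed": None, "logit_bias": None}
+    # since "system_prompt" and "tools" are None and listed in
+    # _keys_that_cannot_be_none, they are dropped before the API call.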
+
+    def basic_request(self, prompt, **kwargs) -> str:
+        """Handles retrieval of completions from Prem AI whilst handling API errors."""
+        all_kwargs = self._get_all_kwargs(**kwargs)
+        messages = []
+
+        if "system_prompt" in all_kwargs:
+            messages.append({"role": "system", "content": all_kwargs["system_prompt"]})
+        messages.append({"role": "user", "content": prompt})
+
+        response = self.client.chat.completions.create(
+            project_id=self.project_id,
+            messages=messages,
+            **all_kwargs,
+        )
+        if not response.choices:
+            raise premai_api_error("ChatResponse must have at least one candidate")
+
+        content = response.choices[0].message.content
+        if not content:
+            raise premai_api_error("ChatResponse content is empty")
+
+        output_text = content or ""
+
+        self.history.append(
+            {
+                "prompt": prompt,
+                "response": content,
+                "kwargs": all_kwargs,
+                "raw_kwargs": kwargs,
+            },
+        )
+
+        return output_text
+
+    @backoff.on_exception(
+        backoff.expo,
+        (premai_api_error),
+        max_time=1000,
+        on_backoff=backoff_hdlr,
+        giveup=giveup_hdlr,
+    )
+    def request(self, prompt, **kwargs) -> str:
+        """Handles retrieval of completions from Prem AI whilst handling API errors."""
+        return self.basic_request(prompt=prompt, **kwargs)
+
+    def __call__(self, prompt, **kwargs):
+        return self.request(prompt, **kwargs)
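+
+# A hedged end-to-end usage sketch (the project ID and API key below are
+# placeholders):
+#
+#   import dspy
+#
+#   lm = dspy.PremAI(project_id=123, api_key="your-premai-api-key")
+#   print(lm("What is a large language model?"))
+#   lm.inspect_history(n=1)  # prints the last prompt/response pair; supported
+#                            # because inspect_history in dsp/modules/lm.py now
+#                            # handles provider == "premai"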
diff --git a/dspy/__init__.py b/dspy/__init__.py
index 653e3cbcb4..2c6c1d7d45 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -45,6 +45,7 @@ AWSMeta = dsp.AWSMeta
 
 Watsonx = dsp.Watsonx
+PremAI = dsp.PremAI
 
 configure = settings.configure
 context = settings.context