Merged
70 changes: 70 additions & 0 deletions docs/api/language_model_clients/PremAI.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
---
sidebar_position: 5
---

# dsp.PremAI

[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform, create your first project, and grab your API key.

### Usage

Please make sure you have the `premai` Python SDK installed. If not, you can install it with:

```bash
pip install -U premai
```

Here is a quick example of how to use the `premai` Python SDK with DSPy:

```python
from dspy import PremAI

llm = PremAI(model='mistral-tiny', project_id=123, api_key="your-premai-api-key")
print(llm("what is a large language model"))
```

> Please note: project ID 123 is just an example. You can find your project ID on the platform, under the project you created.

### Constructor

The constructor initializes the base class `LM` and verifies the `api_key` provided or defined through the `PREMAI_API_KEY` environment variable.

```python
class PremAI(LM):
def __init__(
self,
model: str,
project_id: int,
api_key: str,
base_url: Optional[str] = None,
session_id: Optional[int] = None,
**kwargs,
) -> None:
```

**Parameters:**

- `model` (_str_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) which contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): API key from PremAI. If not provided, it is read from the `PREMAI_API_KEY` environment variable. Defaults to `None`.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
- `**kwargs`: Additional language model arguments passed to the API provider.
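As an illustration of the `api_key` fallback described above, here is a minimal sketch that mirrors the `PREMAI_API_KEY` environment-variable lookup (the helper name is ours, not part of the SDK):

```python
import os

def resolve_premai_api_key(api_key=None):
    # Prefer an explicitly passed key, then fall back to the environment.
    key = api_key or os.environ.get("PREMAI_API_KEY")
    if key is None:
        raise RuntimeError(
            "No API key found; see https://docs.premai.io/introduction",
        )
    return key

os.environ["PREMAI_API_KEY"] = "env-key"
print(resolve_premai_api_key())            # falls back to the environment
print(resolve_premai_api_key("explicit"))  # an explicit argument wins
```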

### Methods

#### `__call__(self, prompt: str, **kwargs) -> str`

Retrieves completions from PremAI by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.


**Parameters:**

- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for the completion request, e.g. `temperature`, `max_tokens`, etc. You can find all the additional kwargs [here](https://docs.premai.io/get-started/sdk#optional-parameters).
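As a rough sketch of how per-call kwargs interact with the client defaults (a simplified, illustrative version of the merging done inside this wrapper; the default values mirror the ones it sets):

```python
DEFAULTS = {"temperature": 0.17, "max_tokens": 150}

def merge_kwargs(**kwargs):
    # Per-call kwargs override the client defaults.
    merged = {**DEFAULTS, **kwargs}
    # Optional keys the API rejects when None are dropped entirely.
    for key in ("system_prompt", "frequency_penalty", "presence_penalty", "tools"):
        if merged.get(key) is None:
            merged.pop(key, None)
    return merged

# temperature is overridden; system_prompt is dropped rather than sent as None
print(merge_kwargs(temperature=0.9, system_prompt=None))
```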
3 changes: 2 additions & 1 deletion docs/docs/building-blocks/1-language_models.md
@@ -137,6 +137,7 @@ lm = dspy.{provider_listed_below}(model="your model", model_request_kwargs="..."

4. `dspy.Together` for various hosted open-source models.

5. `dspy.PremAI` for the best hosted open-source and closed-source models.

### Local LMs.

@@ -173,4 +174,4 @@ model = 'dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1'
model_path = 'dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so'

llama = dspy.ChatModuleClient(model=model, model_path=model_path)
```
```
@@ -0,0 +1,70 @@
## PremAI

[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform, create your first project, and grab your API key.

### Setting up the PremAI Client

The constructor initializes the base class `LM` to support prompting requests to PremAI-hosted models. It requires the following parameters:

- `model` (_str_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) which contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): API key from PremAI. If not provided, it is read from the `PREMAI_API_KEY` environment variable. Defaults to `None`.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
- `**kwargs`: Additional language model arguments passed to the API provider.

Example of PremAI constructor:

```python
class PremAI(LM):
def __init__(
self,
model: str,
project_id: int,
api_key: str,
base_url: Optional[str] = None,
session_id: Optional[int] = None,
**kwargs,
) -> None:
```

### Under the Hood

#### `__call__(self, prompt: str, **kwargs) -> str`

**Parameters:**
- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for the completion request.

**Returns:**
- `str`: Completion string from the chosen LLM provider.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.
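That payload preparation can be sketched roughly as follows (an illustrative stand-in for the chat-message format, not the client's actual code):

```python
def build_messages(prompt, system_prompt=None):
    # A system prompt, when supplied, precedes the user turn.
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages

print(build_messages("Hello", system_prompt="Be terse"))
```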

### Using the PremAI client

```python
premai_client = dspy.PremAI(project_id=1111)
```

Please note that this is a dummy `project_id`. You need to change it to the project ID you want to use with DSPy.

1) Generate responses through DSPy modules.

```python
dspy.configure(lm=premai_client)

# Example DSPy CoT QA program
qa = dspy.ChainOfThought('question -> answer')

response = qa(question="What is the capital of France?")
print(response.answer)
```

2) Generate responses using the client directly.

```python
response = premai_client(prompt='What is the capital of France?')
print(response)
```
1 change: 1 addition & 0 deletions dsp/modules/__init__.py
@@ -20,6 +20,7 @@
from .hf_client import Anyscale, HFClientTGI, Together
from .mistral import *
from .ollama import *
from .premai import PremAI
from .pyserini import *
from .sbert import *
from .sentence_vectorizer import *
15 changes: 11 additions & 4 deletions dsp/modules/lm.py
@@ -52,10 +52,15 @@ def inspect_history(self, n: int = 1, skip: int = 0):
or provider == "groq"
or provider == "Bedrock"
or provider == "Sagemaker"
or provider == "premai"
):
printed.append((prompt, x["response"]))
elif provider == "anthropic":
blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
blocks = [
{"text": block.text}
for block in x["response"].content
if block.type == "text"
]
printed.append((prompt, blocks))
elif provider == "cohere":
printed.append((prompt, x["response"].text))
@@ -85,7 +90,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
if provider == "cohere" or provider == "Bedrock" or provider == "Sagemaker":
text = choices
elif provider == "openai" or provider == "ollama":
text = ' ' + self._get_choice_text(choices[0]).strip()
text = " " + self._get_choice_text(choices[0]).strip()
elif provider == "clarifai" or provider == "claude":
text = choices
elif provider == "groq":
@@ -96,14 +101,16 @@
text = choices[0].message.content
elif provider == "cloudflare":
text = choices[0]
elif provider == "ibm":
elif provider == "ibm" or provider == "premai":
text = choices
else:
text = choices[0]["text"]
printing_value += self.print_green(text, end="")

if len(choices) > 1 and isinstance(choices, list):
printing_value += self.print_red(f" \t (and {len(choices)-1} other completions)", end="")
printing_value += self.print_red(
f" \t (and {len(choices)-1} other completions)", end="",
)

printing_value += "\n\n\n"

173 changes: 173 additions & 0 deletions dsp/modules/premai.py
@@ -0,0 +1,173 @@
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
import premai

premai_api_error = premai.errors.UnexpectedStatus
except (ImportError, AttributeError):
    premai_api_error = Exception


def backoff_hdlr(details) -> None:
"""Handler for the backoff package.

See more at: https://pypi.org/project/backoff/
"""
print(
"Backing off {wait:0.1f} seconds after {tries} tries calling function {target} with kwargs {kwargs}".format(
**details,
),
)


def giveup_hdlr(details) -> bool:
    """Decide whether to give up on a retry; keep retrying on rate-limit errors."""
    if "rate limits" in str(details):
        return False
    return True


def get_premai_api_key(api_key: Optional[str] = None) -> str:
"""Retrieve the PreMAI API key from a passed argument or environment variable."""
api_key = api_key or os.environ.get("PREMAI_API_KEY")
if api_key is None:
raise RuntimeError(
"No API key found. See the quick start guide at https://docs.premai.io/introduction to get your API key.",
)
return api_key


class PremAI(LM):
"""Wrapper around Prem AI's API."""

def __init__(
self,
project_id: int,
model: Optional[str] = None,
api_key: Optional[str] = None,
session_id: Optional[int] = None,
**kwargs,
) -> None:
"""Parameters

project_id: int
"The project ID in which the experiments or deployments are carried out. can find all your projects here: https://app.premai.io/projects/"
model: Optional[str]
The name of model deployed on launchpad. When None, it will show 'default'
api_key: Optional[str]
Prem AI API key, to connect with the API. If not provided then it will check from env var by the name
PREMAI_API_KEY
session_id: Optional[int]
The ID of the session to use. It helps to track the chat history.
**kwargs: dict
Additional arguments to pass to the API provider
"""
model = "default" if model is None else model
super().__init__(model)
if premai_api_error == Exception:
raise ImportError(
"Not loading Prem AI because it is not installed. Install it with `pip install premai`.",
)
self.kwargs = kwargs if kwargs == {} else self.kwargs

self.project_id = project_id
self.session_id = session_id

api_key = get_premai_api_key(api_key=api_key)
self.client = premai.Prem(api_key=api_key)
self.provider = "premai"
self.history: list[dict[str, Any]] = []

**Collaborator:** missing the line `self.history: list[dict[str, Any]] = []`

**Contributor Author:** Thanks for pointing this out, but I wonder: `self.history` is also initialized in the abstract base class, so why do we also define it here?

**Collaborator:** Defining the history here gives some flexibility with the typing (that is not done in the `LM` class). This behavior is maintained in most of the other LM classes (`gpt3.py`, Ollama, etc.).

Feel free to test the PremAI behavior with `lm.inspect_history(n=1)` on your end (where `lm` is defined as your `dspy.PremAI` model) as well.
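The behavior discussed above can be sketched without hitting the API: each request appends a typed record to `self.history`, which `inspect_history` later reads. A minimal, illustrative stand-in (not the real `LM` class):

```python
class FakeLM:
    def __init__(self):
        # Typed history, as in the other LM classes (gpt3.py, Ollama, ...).
        self.history: list[dict] = []

    def basic_request(self, prompt, **kwargs):
        response = f"echo: {prompt}"  # stand-in for the API call
        self.history.append({"prompt": prompt, "response": response, "kwargs": kwargs})
        return response

lm = FakeLM()
lm.basic_request("hello", temperature=0.2)
print(lm.history[-1]["prompt"])  # the last record holds the prompt sent
```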

self.kwargs = {
"temperature": 0.17,
"max_tokens": 150,
**kwargs,
}
if session_id is not None:
self.kwargs["session_id"] = session_id

        # However, changing the model once it is deployed from launchpad
        # is not recommended.

if model != "default":
self.kwargs["model"] = model

def _get_all_kwargs(self, **kwargs) -> dict:
other_kwargs = {
"seed": None,
"logit_bias": None,
"tools": None,
"system_prompt": None,
}
all_kwargs = {
**self.kwargs,
**other_kwargs,
**kwargs,
}

_keys_that_cannot_be_none = [
"system_prompt",
"frequency_penalty",
"presence_penalty",
"tools",
]

for key in _keys_that_cannot_be_none:
if all_kwargs.get(key) is None:
all_kwargs.pop(key, None)
return all_kwargs

def basic_request(self, prompt, **kwargs) -> str:
"""Handles retrieval of completions from Prem AI whilst handling API errors."""
all_kwargs = self._get_all_kwargs(**kwargs)
messages = []

if "system_prompt" in all_kwargs:
messages.append({"role": "system", "content": all_kwargs["system_prompt"]})
messages.append({"role": "user", "content": prompt})

response = self.client.chat.completions.create(
project_id=self.project_id,
messages=messages,
**all_kwargs,
)
if not response.choices:
raise premai_api_error("ChatResponse must have at least one candidate")

content = response.choices[0].message.content
if not content:
raise premai_api_error("ChatResponse is none")

output_text = content or ""

self.history.append(
{
"prompt": prompt,
"response": content,
"kwargs": all_kwargs,
"raw_kwargs": kwargs,
},
)

return output_text

@backoff.on_exception(
backoff.expo,
(premai_api_error),
max_time=1000,
on_backoff=backoff_hdlr,
giveup=giveup_hdlr,
)
def request(self, prompt, **kwargs) -> str:
"""Handles retrieval of completions from Prem AI whilst handling API errors."""
return self.basic_request(prompt=prompt, **kwargs)

def __call__(self, prompt, **kwargs):
return self.request(prompt, **kwargs)
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -45,6 +45,7 @@
AWSMeta = dsp.AWSMeta

Watsonx = dsp.Watsonx
PremAI = dsp.PremAI

configure = settings.configure
context = settings.context