# feat(dspy) PremAI python sdk #1007
---
sidebar_position: 5
---

# dsp.PremAI

[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform, create your first project, and grab your API key.

### Usage

Make sure you have the `premai` Python SDK installed. If not, install it with:

```bash
pip install -U premai
```

Here is a quick example of using the `premai` SDK with DSPy:

```python
from dspy import PremAI

llm = PremAI(model='mistral-tiny', project_id=123, api_key="your-premai-api-key")
print(llm("what is a large language model"))
```

> Please note: Project ID 123 is just an example. You can find your project ID in the PremAI platform, under the project you created.

### Constructor

The constructor initializes the base class `LM` and verifies the `api_key` provided directly or through the `PREMAI_API_KEY` environment variable.

```python
class PremAI(LM):
    def __init__(
        self,
        project_id: int,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
```

**Parameters:**

- `model` (_Optional[str]_, _optional_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in the [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project ID](https://docs.premai.io/get-started/projects) that contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): API key from PremAI. If not provided, it is read from the `PREMAI_API_KEY` environment variable. Defaults to None.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps track the chat history.
- `**kwargs`: Additional language model arguments that will be passed to the API provider.

### Methods

#### `__call__(self, prompt: str, **kwargs) -> str`

Retrieves completions from PremAI by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

After generation, the completions are post-processed based on the `model_type` parameter.

**Parameters:**

- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for the completion request, e.g. `temperature`, `max_tokens`, etc. You can find all the additional kwargs [here](https://docs.premai.io/get-started/sdk#optional-parameters).
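To illustrate how per-call kwargs interact with the defaults set at construction time, here is a minimal, self-contained sketch. This is a simplified re-implementation for illustration only (the `merge_kwargs` helper is hypothetical, not part of the SDK); the real client additionally filters out a fixed set of None-valued keys:

```python
# Constructor-time defaults, mirroring the values used by the PremAI wrapper.
defaults = {"temperature": 0.17, "max_tokens": 150}

def merge_kwargs(defaults: dict, **overrides) -> dict:
    """Per-call overrides win over constructor defaults; None values are dropped."""
    merged = {**defaults, **overrides}
    return {k: v for k, v in merged.items() if v is not None}

print(merge_kwargs(defaults, temperature=0.5))
# {'temperature': 0.5, 'max_tokens': 150}
```

Passing `None` for a parameter removes it from the payload rather than sending an explicit null to the API.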
## PremAI

[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform, create your first project, and grab your API key.

### Setting up the PremAI Client

The constructor initializes the base class `LM` to support prompting requests to supported PremAI-hosted models. It takes the following parameters:

- `model` (_Optional[str]_, _optional_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in the [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project ID](https://docs.premai.io/get-started/projects) that contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): API key from PremAI. If not provided, it is read from the `PREMAI_API_KEY` environment variable. Defaults to None.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps track the chat history.
- `**kwargs`: Additional language model arguments that will be passed to the API provider.

Example of the PremAI constructor:

```python
class PremAI(LM):
    def __init__(
        self,
        project_id: int,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
```

### Under the Hood

#### `__call__(self, prompt: str, **kwargs) -> str`

**Parameters:**

- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for the completion request.

**Returns:**

- `str`: Completion string from the chosen LLM provider.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

### Using the PremAI client

```python
premai_client = dspy.PremAI(project_id=1111)
```

Please note that this is a dummy `project_id`. You need to change it to the project ID you want to use with DSPy.

1) Generate responses through DSPy modules:

```python
dspy.configure(lm=premai_client)

# Example DSPy CoT QA program
qa = dspy.ChainOfThought('question -> answer')

response = qa(question="What is the capital of Paris?")
print(response.answer)
```

2) Generate responses using the client directly:

```python
response = premai_client(prompt='What is the capital of Paris?')
print(response)
```
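The client resolves the API key by preferring an explicitly passed argument and falling back to the `PREMAI_API_KEY` environment variable. The following is a standalone sketch of that resolution logic (a simplified stand-in for the module's helper, not the packaged function itself):

```python
import os

def resolve_premai_api_key(api_key=None):
    """Return the explicit key if given, else fall back to PREMAI_API_KEY."""
    key = api_key or os.environ.get("PREMAI_API_KEY")
    if key is None:
        raise RuntimeError(
            "No API key found. See https://docs.premai.io/introduction to get your API key."
        )
    return key

# Demo with a dummy value; in practice you export a real key in your shell.
os.environ["PREMAI_API_KEY"] = "dummy-key-for-demo"
print(resolve_premai_api_key())           # dummy-key-for-demo
print(resolve_premai_api_key("explicit")) # explicit
```

Exporting the key once in your environment means you never have to hard-code it in scripts or notebooks.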
```python
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
    import premai

    premai_api_error = premai.errors.UnexpectedStatus
except ImportError:
    premai_api_error = Exception
except AttributeError:
    premai_api_error = Exception


def backoff_hdlr(details) -> None:
    """Handler for the backoff package.

    See more at: https://pypi.org/project/backoff/
    """
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries calling function {target} with kwargs {kwargs}".format(
            **details,
        ),
    )


def giveup_hdlr(details) -> bool:
    """Wrapper function that decides when to give up on retry."""
    if "rate limits" in details.message:
        return False
    return True


def get_premai_api_key(api_key: Optional[str] = None) -> str:
    """Retrieve the PremAI API key from a passed argument or environment variable."""
    api_key = api_key or os.environ.get("PREMAI_API_KEY")
    if api_key is None:
        raise RuntimeError(
            "No API key found. See the quick start guide at https://docs.premai.io/introduction to get your API key.",
        )
    return api_key


class PremAI(LM):
    """Wrapper around Prem AI's API."""

    def __init__(
        self,
        project_id: int,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Parameters

        project_id: int
            The project ID in which the experiments or deployments are carried out. You can find all your projects here: https://app.premai.io/projects/
        model: Optional[str]
            The name of the model deployed on launchpad. When None, it defaults to 'default'.
        api_key: Optional[str]
            PremAI API key, used to connect with the API. If not provided, it is read from the
            PREMAI_API_KEY environment variable.
        session_id: Optional[int]
            The ID of the session to use. It helps track the chat history.
        **kwargs: dict
            Additional arguments to pass to the API provider.
        """
        model = "default" if model is None else model
        super().__init__(model)
        if premai_api_error == Exception:
            raise ImportError(
                "Not loading Prem AI because it is not installed. Install it with `pip install premai`.",
            )
        self.kwargs = kwargs if kwargs == {} else self.kwargs

        self.project_id = project_id
        self.session_id = session_id

        api_key = get_premai_api_key(api_key=api_key)
        self.client = premai.Prem(api_key=api_key)
        self.provider = "premai"
        self.history: list[dict[str, Any]] = []
```
> **Collaborator:** missing the line
>
> **Contributor (Author):** Thanks for pointing this out, but I wonder,
>
> **Collaborator:** Defining the history here gives some flexibility with the typing (that is not done in the `LM` class). This behavior is maintained in most of the other LM classes (`gpt3.py`, Ollama, etc.). Feel free to test the PremAI behavior with `lm.inspect_history(n=1)` on your end (where `lm` is defined as your `dspy.PremAI` model) as well.
```python
        self.kwargs = {
            "temperature": 0.17,
            "max_tokens": 150,
            **kwargs,
        }
        if session_id is not None:
            self.kwargs["session_id"] = session_id

        # However, it is not recommended to change the model once
        # it is deployed from launchpad
        if model != "default":
            self.kwargs["model"] = model

    def _get_all_kwargs(self, **kwargs) -> dict:
        other_kwargs = {
            "seed": None,
            "logit_bias": None,
            "tools": None,
            "system_prompt": None,
        }
        all_kwargs = {
            **self.kwargs,
            **other_kwargs,
            **kwargs,
        }

        _keys_that_cannot_be_none = [
            "system_prompt",
            "frequency_penalty",
            "presence_penalty",
            "tools",
        ]

        for key in _keys_that_cannot_be_none:
            if all_kwargs.get(key) is None:
                all_kwargs.pop(key, None)
        return all_kwargs

    def basic_request(self, prompt, **kwargs) -> str:
        """Handles retrieval of completions from Prem AI whilst handling API errors."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        messages = []

        if "system_prompt" in all_kwargs:
            messages.append({"role": "system", "content": all_kwargs["system_prompt"]})
        messages.append({"role": "user", "content": prompt})

        response = self.client.chat.completions.create(
            project_id=self.project_id,
            messages=messages,
            **all_kwargs,
        )
        if not response.choices:
            raise premai_api_error("ChatResponse must have at least one candidate")

        content = response.choices[0].message.content
        if not content:
            raise premai_api_error("ChatResponse is none")

        output_text = content or ""

        self.history.append(
            {
                "prompt": prompt,
                "response": content,
                "kwargs": all_kwargs,
                "raw_kwargs": kwargs,
            },
        )

        return output_text

    @backoff.on_exception(
        backoff.expo,
        (premai_api_error),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt, **kwargs) -> str:
        """Handles retrieval of completions from Prem AI whilst handling API errors."""
        return self.basic_request(prompt=prompt, **kwargs)

    def __call__(self, prompt, **kwargs):
        return self.request(prompt, **kwargs)
```
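The retry policy above gives up on any API error except rate-limit errors, which keep being retried with exponential backoff. A minimal sketch of that give-up decision follows, using a hypothetical `FakeError` stand-in for the error object the handler receives:

```python
class FakeError:
    """Stand-in for the API error object passed to the give-up handler."""
    def __init__(self, message):
        self.message = message

def giveup_hdlr(details) -> bool:
    """Return True to give up; retry only when the error mentions rate limits."""
    if "rate limits" in details.message:
        return False  # keep retrying with backoff
    return True       # give up on everything else

print(giveup_hdlr(FakeError("exceeded rate limits")))  # False
print(giveup_hdlr(FakeError("invalid project id")))    # True
```

Returning `False` tells the backoff decorator to retry; returning `True` surfaces the error to the caller immediately.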