86 changes: 86 additions & 0 deletions docs/api/language_model_clients/aws_models.md
@@ -0,0 +1,86 @@
---
sidebar_position: 9
---

# dspy.AWSMistral, dspy.AWSAnthropic, dspy.AWSMeta

### Usage

```python
# Notes:
# 1. Install boto3 to use AWS models.
# 2. Configure your AWS credentials with the AWS CLI before using these models.

# Initialize the Bedrock AWS provider.
bedrock = dspy.Bedrock(region_name="us-west-2")
# For Mixtral on Bedrock
lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
# For Claude 3 Haiku on Bedrock
lm = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs)
# For Llama 2 on Bedrock
lm = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs)

# Initialize the SageMaker AWS provider.
sagemaker = dspy.Sagemaker(region_name="us-west-2")
# For Mistral on SageMaker
# Note: you must first create a SageMaker endpoint for the Mistral model.
lm = dspy.AWSMistral(sagemaker, "<YOUR_MISTRAL_ENDPOINT_NAME>", **kwargs)

```

### Constructor

The `AWSMistral` constructor initializes the base class `AWSModel` which itself inherits from the `LM` class.

```python
class AWSMistral(AWSModel):
    """Mistral family of models."""

    def __init__(
        self,
        aws_provider: AWSProvider,
        model: str,
        max_context_size: int = 32768,
        max_new_tokens: int = 1500,
        **kwargs,
    ) -> None:
```

**Parameters:**
- `aws_provider` (AWSProvider): The AWS provider to use. One of `dspy.Bedrock` or `dspy.Sagemaker`.
- `model` (_str_): Mistral AI pretrained models. For Bedrock, this is the Model ID in https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns. For Sagemaker, this is the endpoint name.
- `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768.
- `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500.
- `**kwargs`: Additional language model arguments to pass to the API provider.

### Methods

```python
def _format_prompt(self, raw_prompt: str) -> str:
```
This function formats the prompt for the model. Refer to the model card for the specific formatting required.
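For illustration, a Mistral-instruct-style formatter might look like the sketch below. This is a hypothetical standalone helper following Mistral's published instruction template, not the library's implementation:

```python
def format_mistral_prompt(raw_prompt: str) -> str:
    """Wrap a raw prompt in the [INST] tags that Mistral instruct models expect."""
    return f"<s>[INST] {raw_prompt} [/INST]"

print(format_mistral_prompt("What is Amazon Bedrock?"))
```

Anthropic and Meta models each expect a different template, which is why this method is model-specific.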

<br/>

```python
def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
```
This function builds the body of the request to the model. It takes the prompt and any additional keyword arguments, and returns a tuple of the number of tokens to generate and a dictionary of request fields (including the prompt) that make up the request body.
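A minimal sketch of that contract is shown below. The helper and its parameter names are assumptions for illustration, not the library source:

```python
def create_body(prompt: str, max_new_tokens: int = 1500, **kwargs) -> tuple[int, dict]:
    """Return (tokens to generate, request-body dict) for a text-generation call."""
    # Caller-supplied max_tokens overrides the model default.
    max_tokens = kwargs.pop("max_tokens", max_new_tokens)
    body = {"prompt": prompt, "max_tokens": max_tokens, **kwargs}
    return max_tokens, body

n_tokens, body = create_body("Hello", temperature=0.7)
```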

<br/>

```python
def _call_model(self, body: str) -> str:
```
This function calls the model using the provider `call_model()` function and extracts the generated text (completion) from the provider-specific response.
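As an illustration, extracting the completion from a Mistral-on-Bedrock-style JSON response could look like this. The response schema here is an assumption; check the model card for the actual shape:

```python
import json

def extract_completion(response_body: str) -> str:
    """Parse a provider response string and return the generated text."""
    payload = json.loads(response_body)
    return payload["outputs"][0]["text"]

raw = json.dumps({"outputs": [{"text": "Hello from Bedrock"}]})
print(extract_completion(raw))
```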

<br/>

The model-specific methods above are called by `AWSModel::basic_request()`, the main entry point for querying the model. It takes the prompt and any additional keyword arguments and calls `AWSModel::_simple_api_call()`, which delegates to the model-specific `_create_body()` and `_call_model()` methods to build the request body, invoke the model, and extract the generated text.
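That call chain can be sketched as follows. The method names come from this page, but the bodies are illustrative stand-ins, not the library source:

```python
import json

class AWSModelSketch:
    """Toy model showing how basic_request delegates to model-specific hooks."""

    def basic_request(self, prompt: str, **kwargs) -> str:
        return self._simple_api_call(self._format_prompt(prompt), **kwargs)

    def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str:
        _, body = self._create_body(formatted_prompt, **kwargs)
        return self._call_model(json.dumps(body))

    # Model-specific hooks, stubbed here for demonstration.
    def _format_prompt(self, raw_prompt: str) -> str:
        return f"[INST] {raw_prompt} [/INST]"

    def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict]:
        return kwargs.get("max_tokens", 1500), {"prompt": prompt, **kwargs}

    def _call_model(self, body: str) -> str:
        return f"echo: {json.loads(body)['prompt']}"

print(AWSModelSketch().basic_request("Hi"))
```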


Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation for information on the `LM` base class functionality.

<br/>

`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`.
53 changes: 53 additions & 0 deletions docs/api/language_model_clients/aws_providers.md
@@ -0,0 +1,53 @@
---
sidebar_position: 9
---

# dspy.Bedrock, dspy.Sagemaker

### Usage

The `AWSProvider` class is the base class for the AWS providers - `dspy.Bedrock` and `dspy.Sagemaker`. An instance of one of these providers is passed to the constructor when creating an instance of an AWS model class (e.g., `dspy.AWSMistral`) that is ultimately used to query the model.

```python
# Notes:
# 1. Install boto3 to use AWS models.
# 2. Configure your AWS credentials with the AWS CLI before using these models.

# Initialize the Bedrock AWS provider.
bedrock = dspy.Bedrock(region_name="us-west-2")

# Initialize the SageMaker AWS provider.
sagemaker = dspy.Sagemaker(region_name="us-west-2")
```

### Constructor

The `Bedrock` constructor initializes the base class `AWSProvider`.

```python
class Bedrock(AWSProvider):
    """This class adds support for Bedrock models."""

    def __init__(
        self,
        region_name: str,
        profile_name: Optional[str] = None,
        batch_n_enabled: bool = False,  # This has to be set up manually on Bedrock.
    ) -> None:
```

**Parameters:**
- `region_name` (str): The AWS region where this LM is hosted.
- `profile_name` (str, optional): boto3 credentials profile.
- `batch_n_enabled` (bool): If False, call the LM N times rather than batching.
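To illustrate what `batch_n_enabled` controls, the sketch below emulates `n` completions with sequential calls when batching is off. This is a hypothetical helper, not the library source:

```python
def generate_n(call_once, n: int, batch_n_enabled: bool) -> list[str]:
    """Return n completions, batching only when the backend supports it."""
    if batch_n_enabled:
        return call_once(n=n)  # one request returns n completions
    return [call_once(n=1)[0] for _ in range(n)]  # fall back to n sequential requests

calls = []
def fake_call(n: int) -> list[str]:
    calls.append(n)
    return [f"completion-{i}" for i in range(n)]

print(generate_n(fake_call, 3, batch_n_enabled=False))
```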

### Methods

```python
def call_model(self, model_id: str, body: str) -> str:
```
This function implements the actual invocation of the model on AWS using the boto3 client.
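A sketch of a Bedrock invocation via the boto3 runtime client is shown below. The client is passed in explicitly here for clarity (the library constructs its own client internally); `invoke_model` and the streaming response `body` are standard boto3 Bedrock Runtime behavior:

```python
def call_model(client, model_id: str, body: str) -> str:
    """Invoke a Bedrock model and return the raw JSON response body as text."""
    response = client.invoke_model(
        modelId=model_id,
        body=body,
        accept="application/json",
        contentType="application/json",
    )
    return response["body"].read().decode("utf-8")

# In real use: client = boto3.client("bedrock-runtime", region_name="us-west-2")
```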

<br/>

`Sagemaker` works exactly the same as `Bedrock`.
7 changes: 5 additions & 2 deletions dsp/modules/__init__.py
@@ -1,6 +1,10 @@
 from .anthropic import Claude
+from .aws_models import AWSAnthropic, AWSMeta, AWSMistral, AWSModel
+
+# Below is obsolete. It has been replaced with Bedrock class in dsp/modules/aws_providers.py
+# from .bedrock import *
+from .aws_providers import Bedrock, Sagemaker
 from .azure_openai import AzureOpenAI
-from .bedrock import *
 from .cache_utils import *
 from .clarifai import *
 from .cohere import *
@@ -17,4 +21,3 @@
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
-
186 changes: 0 additions & 186 deletions dsp/modules/aws_lm.py

This file was deleted.
