From 34774912729c4f78553fe1f1ad7a3729c89bde92 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 13:38:55 -0500 Subject: [PATCH 01/15] feature(modules): added support for aws providers (bedrock and sagemaker) and models (mistral, anthropic, meta) --- docs/api/language_model_clients/aws.md | 61 +++++ dsp/modules/__init__.py | 7 +- dsp/modules/aws_models.py | 300 +++++++++++++++++++++++++ dsp/modules/aws_providers.py | 128 +++++++++++ dsp/modules/bedrock.py | 79 ------- dspy/__init__.py | 6 + pyproject.toml | 4 +- tests/modules/test_aws_models.py | 59 +++++ 8 files changed, 562 insertions(+), 82 deletions(-) create mode 100644 docs/api/language_model_clients/aws.md create mode 100644 dsp/modules/aws_models.py create mode 100644 dsp/modules/aws_providers.py delete mode 100644 dsp/modules/bedrock.py create mode 100644 tests/modules/test_aws_models.py diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md new file mode 100644 index 0000000000..afac264f3d --- /dev/null +++ b/docs/api/language_model_clients/aws.md @@ -0,0 +1,61 @@ +--- +sidebar_position: 9 +--- + +# dsp.AWSMistral, dsp.AWSAnthropic, dsp.AWSMeta + +### Usage + +```python +# Notes: +# 1. Install boto3 to use AWS models. +# 2. 
Configure your AWS credentials with the AWS CLI before using these models + +# initialize the bedrock aws provider +bedrock = Bedrock(region_name="us-west-2") +# For mixtral on Bedrock +lm = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) +# For haiku on Bedrock +lm = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) +# For llama2 on Bedrock +lm = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + +# initialize the sagemaker aws provider +sagemaker = Sagemaker(region_name="us-west-2") +# For mistral on Sagemaker +# Note: you need to create a Sagemaker endpoint for the mistral model first +lm = AWSMistral(sagemaker, "", **kwargs) + +``` + +### Constructor + +The constructor initializes the base class `LM` and the `AWSProvider` class. + +```python +class AWSMistral(AWSModel): + """Mistral family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 32768, + max_new_tokens: int = 1500, + **kwargs + ) -> None: +``` + +**Parameters:** +- `aws_provider` (AWSProvider): The aws provider to use. One of `Bedrock` or `Sagemaker`. +- `model` (_str_): Mistral AI pretrained models. Defaults to `mistral-medium-latest`. +- `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768. +- `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500. +- `**kwargs`: Additional language model arguments to pass to the API provider. + +### Methods + +Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation. + + +`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. 
\ No newline at end of file diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index b7a739f504..935f65b403 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -1,6 +1,10 @@ from .anthropic import Claude +from .aws_models import AWSAnthropic, AWSMeta, AWSMistral, AWSModel + +# Below is obsolete. It has been replaced with Bedrock class in dsp/modules/aws_providers.py +# from .bedrock import * +from .aws_providers import Bedrock, Sagemaker from .azure_openai import AzureOpenAI -from .bedrock import * from .cache_utils import * from .clarifai import * from .cohere import * @@ -17,4 +21,3 @@ from .pyserini import * from .sbert import * from .sentence_vectorizer import * - diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py new file mode 100644 index 0000000000..3b93948897 --- /dev/null +++ b/dsp/modules/aws_models.py @@ -0,0 +1,300 @@ +"""AWS models for LMs.""" + +from __future__ import annotations + +import json +import logging +from abc import abstractmethod +from typing import Any + +from dsp.modules.aws_providers import AWSProvider, Bedrock, Sagemaker +from dsp.modules.lm import LM + +# Heuristic translating number of chars to tokens +# ~4 chars = 1 token +CHARS2TOKENS: int = 4 + + +class AWSModel(LM): + """This class adds support for an AWS model. + It is an abstract class and should not be instantiated directly. + Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta. + The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker. 
+ Usage Example: + bedrock = Bedrock(region_name="us-west-2") + bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + + sagemaker = Sagemaker(region_name="us-west-2") + sagemaker_model = AWSMistral(sagemaker, "", **kwargs) + """ + + def __init__( + self, + model: str, + max_context_size: int, + max_new_tokens: int, + **kwargs, + ) -> None: + """_summary_. + + Args: + model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint. + max_context_size (int): The maximum context size in tokens. + max_new_tokens (int): The maximum number of tokens to be sampled from the LM. + """ + super().__init__(model=model) + self._model_name: str = model + self._max_context_size: int = max_context_size + self._max_new_tokens: int = max_new_tokens + + self.kwargs = { + **self.kwargs, + **kwargs, + } + + @abstractmethod + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + pass + + @abstractmethod + def _call_model(self, body: str) -> str | list[str]: + """Call model, get generated input without the formatted prompt.""" + + def _estimate_tokens(self, text: str) -> int: + return len(text)/CHARS2TOKENS + + def _extract_input_parameters( + self, + body: dict[Any, Any], + ) -> dict[str, str | float | int]: + return body + + def _format_prompt(self, raw_prompt: str) -> str: + return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" + + def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]: + n, body = self._create_body(formatted_prompt, **kwargs) + json_body = json.dumps(body) + + if n > 1: + llm_out = [self._call_model(json_body) for _ in range(n)] + llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out] + else: + llm_out = self._call_model(json_body) + llm_out = llm_out.replace(formatted_prompt, "") + + 
self.history.append( + {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}, + ) + return llm_out + + def basic_request(self, prompt, **kwargs) -> str | list[str]: + """Query the endpoint.""" + token_count = self._estimate_tokens(prompt) + if token_count > self._max_context_size: + logging.info("Error - input tokens %s exceeds max context %s", token_count, self._max_context_size) + raise ValueError( + f"Error - input tokens {token_count} exceeds max context {self._max_context_size}", + ) + + formatted_prompt: str = self._format_prompt(prompt) + return self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs) + + def __call__( + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, + ) -> list[str]: + """Query the AWS LLM. + + There is only support for only_completed=True and return_sorted=False + right now. + """ + assert only_completed, "for now" + assert return_sorted is False, "for now" + + generated = self.basic_request(prompt, **kwargs) + return [generated] + + +class AWSMistral(AWSModel): + """Mistral family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 32768, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() + + self.kwargs["stop"] = "\n\n---" + + def _format_prompt(self, raw_prompt: str) -> str: + return " [INST] Human: " + raw_prompt + " [/INST] Assistant: " + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, base_args = self.aws_provider.sanitize_kwargs(base_args) + + query_args: dict[str, str | float] = 
{} + if isinstance(self.aws_provider, Bedrock): + query_args["prompt"] = prompt + elif isinstance(self.aws_provider, Sagemaker): + query_args["parameters"] = base_args + query_args["inputs"] = prompt + else: + raise ValueError("Error - provider not recognized") + + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.aws_provider.call_model( + model_id=self._model_name, + body=body, + ) + if isinstance(self.aws_provider, Bedrock): + response_body = json.loads(response["body"].read()) + completion = response_body["outputs"][0]["text"] + elif isinstance(self.aws_provider, Sagemaker): + response_body = json.loads(response["Body"].read()) + completion = response_body[0]["generated_text"] + else: + raise ValueError("Error - provider not recognized") + + completion = completion.split(self.kwargs["stop"])[0] + return completion + + +class AWSAnthropic(AWSModel): + """Anthropic family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 200000, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.provider = aws_provider + + for k, v in kwargs.items(): + self.kwargs[k] = v + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, query_args = self.provider.sanitize_kwargs(base_args) + + # Anthropic models do not support the following parameters + query_args.pop("frequency_penalty", None) + query_args.pop("num_generations", None) + query_args.pop("presence_penalty", None) + query_args.pop("model", None) + + # we are using the Claude messages API + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html 
+ query_args["messages"] = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + ], + }, + ] + query_args["anthropic_version"] = "bedrock-2023-05-31" + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.provider.predictor.invoke_model( + modelId=self._model_name, + body=body, + ) + response_body = json.loads(response["body"].read()) + completion = response_body["content"][0]["text"] + return completion + + +class AWSMeta(AWSModel): + """Llama2 family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 4096, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.provider = aws_provider + + for k, v in kwargs.items(): + self.kwargs[k] = v + + self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens") + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, query_args = self.provider.sanitize_kwargs(base_args) + + # Meta models do not support the following parameters + query_args.pop("frequency_penalty", None) + query_args.pop("num_generations", None) + query_args.pop("presence_penalty", None) + query_args.pop("model", None) + + query_args["prompt"] = prompt + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.provider.predictor.invoke_model( + modelId=self._model_name, + body=body, + ) + response_body = json.loads(response["body"].read()) + completion = response_body["generation"] + + stop = "\n\n" + completion = completion.split(stop)[0] + + return completion diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py new file mode 100644 index 
0000000000..ae3f0d4c3f --- /dev/null +++ b/dsp/modules/aws_providers.py @@ -0,0 +1,128 @@ +"""AWS providers for LMs.""" + +from abc import ABC, abstractmethod +from typing import Any, Optional + + +class AWSProvider(ABC): + """This abstract class adds support for AWS model providers such as Bedrock and SageMaker. + The subclasses such as Bedrock and Sagemaker implement the abstract method _call_model and work in conjunction with the AWSModel classes. + Usage Example: + bedrock = Bedrock(region_name="us-west-2") + bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + + sagemaker = Sagemaker(region_name="us-west-2") + sagemaker_model = AWSMistral(sagemaker, "", **kwargs) + """ + + def __init__( + self, + region_name: str, + service_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = True, + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + service_name (str): Used in context of invoking the boto3 API. + profile_name (str, optional): boto3 credentials profile. + batch_n_enabled (bool): If False, call the LM N times rather than batching. 
+ """ + try: + import boto3 # pylint: disable=import-outside-toplevel + except ImportError as exc: + raise ImportError('Please install boto3 to use AWS models.') from exc + + if profile_name is None: + self.predictor = boto3.client(service_name, region_name=region_name) + else: + self.predictor = boto3.Session(profile_name=profile_name).client( + service_name, + region_name=region_name, + ) + + self.batch_n_enabled = batch_n_enabled + + def get_provider_name(self) -> str: + """Return the provider name.""" + return self.__class__.__name__ + + @abstractmethod + def call_model(self, model_id: str, body: str) -> str: + """Call the model and return the response.""" + + def sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> tuple[int, dict[str, Any]]: + """Ensure that input kwargs can be used by Bedrock or Sagemaker.""" + if "temperature" in query_kwargs: + if query_kwargs["temperature"] > 0.99: + query_kwargs["temperature"] = 0.99 + if query_kwargs["temperature"] < 0.01: + query_kwargs["temperature"] = 0.01 + + if "top_p" in query_kwargs: + if query_kwargs["top_p"] > 0.99: + query_kwargs["top_p"] = 0.99 + if query_kwargs["top_p"] < 0.01: + query_kwargs["top_p"] = 0.01 + + n = -1 + if not self.batch_n_enabled: + n = query_kwargs.pop('n', 1) + query_kwargs["num_generations"] = n + + return n, query_kwargs + + +class Bedrock(AWSProvider): + """This class adds support for Bedrock models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = False, # This has to be setup manually on Bedrock. + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + profile_name (str, optional): boto3 credentials profile. 
+ """ + super().__init__(region_name, "bedrock-runtime", profile_name, batch_n_enabled) + + def call_model(self, model_id: str, body: str) -> str: + return self.predictor.invoke_model( + modelId=model_id, + body=body, + accept="application/json", + contentType="application/json", + ) + + +class Sagemaker(AWSProvider): + """This class adds support for Sagemaker models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + profile_name (str, optional): boto3 credentials profile. + """ + super().__init__(region_name, "runtime.sagemaker", profile_name) + + def call_model(self, model_id: str, body: str) -> str: + return self.predictor.invoke_endpoint( + EndpointName=model_id, + Body=body, + Accept="application/json", + ContentType="application/json", + ) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py deleted file mode 100644 index 252c87fe86..0000000000 --- a/dsp/modules/bedrock.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, Optional - -from dsp.modules.aws_lm import AWSLM - - -class Bedrock(AWSLM): - def __init__( - self, - region_name: str, - model: str, - profile_name: Optional[str] = None, - input_output_ratio: int = 3, - max_new_tokens: int = 1500, - ) -> None: - """Use an AWS Bedrock language model. - NOTE: You must first configure your AWS credentials with the AWS CLI before using this model! - - Args: - region_name (str, optional): The AWS region where this LM is hosted. - model (str, optional): An AWS Bedrock LM name. You can find available models with the AWS CLI as follows: aws bedrock list-foundation-models --query "modelSummaries[*].modelId". - temperature (float, optional): Default temperature for LM. Defaults to 0. - input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. 
Defaults to 3. - max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. - """ - super().__init__( - model=model, - service_name="bedrock-runtime", - region_name=region_name, - profile_name=profile_name, - truncate_long_prompts=False, - input_output_ratio=input_output_ratio, - max_new_tokens=max_new_tokens, - batch_n=True, # Bedrock does not support the `n` parameter - ) - self._validate_model(model) - self.provider = "claude" if "claude" in model.lower() else "bedrock" - - def _validate_model(self, model: str) -> None: - if "claude" not in model.lower(): - raise NotImplementedError("Only claude models are supported as of now") - - def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]: - base_args: dict[str, Any] = { - "max_tokens_to_sample": self._max_new_tokens, - } - for k, v in kwargs.items(): - base_args[k] = v - query_args: dict[str, Any] = self._sanitize_kwargs(base_args) - query_args["prompt"] = prompt - # AWS Bedrock forbids these keys - if "max_tokens" in query_args: - max_tokens: int = query_args["max_tokens"] - input_tokens: int = self._estimate_tokens(prompt) - max_tokens_to_sample: int = max_tokens - input_tokens - del query_args["max_tokens"] - query_args["max_tokens_to_sample"] = max_tokens_to_sample - return query_args - - def _call_model(self, body: str) -> str: - response = self.predictor.invoke_model( - modelId=self._model_name, - body=body, - accept="application/json", - contentType="application/json", - ) - response_body = json.loads(response["body"].read()) - completion = response_body["completion"] - return completion - - def _extract_input_parameters( - self, body: dict[Any, Any], - ) -> dict[str, str | float | int]: - return body - - def _format_prompt(self, raw_prompt: str) -> str: - return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" diff --git a/dspy/__init__.py b/dspy/__init__.py index 1c6539f59b..2d5ba4f003 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -33,7 +33,13 @@ 
Together = dsp.Together HFModel = dsp.HFModel OllamaLocal = dsp.OllamaLocal + Bedrock = dsp.Bedrock +Sagemaker = dsp.Sagemaker +AWSModel = dsp.AWSModel +AWSMistral = dsp.AWSMistral +AWSAnthropic = dsp.AWSAnthropic +AWSMeta = dsp.AWSMeta configure = settings.configure context = settings.context diff --git a/pyproject.toml b/pyproject.toml index 2dafeab07b..f564865f2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ marqo = ["marqo"] pinecone = ["pinecone-client~=2.2.4"] weaviate = ["weaviate-client~=4.5.4"] milvus = ["pymilvus~=2.3.7"] +aws = ["boto3~=1.34.78"] docs = [ "sphinx>=4.3.0", "furo>=2023.3.27", @@ -58,7 +59,6 @@ docs = [ "autodoc_pydantic", "sphinx-reredirects>=0.1.2", "sphinx-automodapi==0.16.0", - ] dev = ["pytest>=6.2.5"] @@ -100,6 +100,7 @@ qdrant-client = { version = "^1.6.2", optional = true } pinecone-client = { version = "^2.2.4", optional = true } weaviate-client = { version = "^4.5.4", optional = true } pymilvus = { version = "^2.3.6", optional = true } +boto3 = { version = "^1.34.78", optional = true } sphinx = { version = ">=4.3.0", optional = true } furo = { version = ">=2023.3.27", optional = true } docutils = { version = "<0.17", optional = true } @@ -134,6 +135,7 @@ marqo = ["marqo"] pinecone = ["pinecone-client"] weaviate = ["weaviate-client"] milvus = ["pymilvus"] +aws = ["boto3"] postgres = ["psycopg2", "pgvector"] docs = [ "sphinx", diff --git a/tests/modules/test_aws_models.py b/tests/modules/test_aws_models.py new file mode 100644 index 0000000000..89cedcc6f7 --- /dev/null +++ b/tests/modules/test_aws_models.py @@ -0,0 +1,59 @@ +"""Tests for AWS models. +Note: Requires configuration of your AWS credentials with the AWS CLI and creating sagemaker endpoints. +TODO: Create mock fixtures for pytest to remove the need for AWS credentials and endpoints. 
+""" + +import dsp +import dspy + +def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM: + """get the language model""" + # extract model vendor and name from model name + # Model path format is / + model_vendor = model_path.split('/')[0] + model_name = model_path.split('/')[1] + + if lm_provider == 'Bedrock': + bedrock = dspy.Bedrock(region_name="us-west-2") + if model_vendor == 'mistral': + return dspy.AWSMistral(bedrock, model_name, **kwargs) + elif model_vendor == 'anthropic': + return dspy.AWSAnthropic(bedrock, model_name, **kwargs) + elif model_vendor == 'meta': + return dspy.AWSMeta(bedrock, model_name, **kwargs) + else: + raise ValueError("Model vendor missing or unsupported: Model path format is /") + elif lm_provider == 'Sagemaker': + sagemaker = dspy.Sagemaker(region_name="us-west-2") + if model_vendor == 'mistral': + return dspy.AWSMistral(sagemaker, model_name, **kwargs) + elif model_vendor == 'meta': + return dspy.AWSMeta(sagemaker, model_name, **kwargs) + else: + raise ValueError("Model vendor missing or unsupported: Model path format is /") + else: + raise ValueError(f"Unsupported model: {model_name}") + +def run_tests(): + """Test the providers and models""" + # Configure your AWS credentials with the AWS CLI before running this script + provider_model_tuples = [ + ('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'), + ('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'), + ('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'), + ('Bedrock', 'meta/meta.llama2-70b-chat-v1'), + # ('Sagemaker', 'mistral/'), # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint + ] + + predict_func = dspy.Predict("question -> answer") + for provider, model_path in provider_model_tuples: + print(f"Provider: {provider}, Model: {model_path}") + lm = get_lm(provider, model_path) + with dspy.context(lm=lm): + question = "What is the capital of France?" 
+ answer = predict_func(question=question).answer + print(f"Question: {question}\nAnswer: {answer}\n\n") + + +if __name__ == "__main__": + run_tests() From d19205741491d084ab47856831ac88d1fc9ecf87 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 13:47:22 -0500 Subject: [PATCH 02/15] feature(modules): support inspect_history for aws models --- dsp/modules/lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index 4af1dc80ec..9364f38e03 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -46,7 +46,7 @@ def inspect_history(self, n: int = 1, skip: int = 0): prompt = x["prompt"] if prompt != last_prompt: - if provider == "clarifai" or provider == "google": + if provider == "clarifai" or provider == "google" or provider == "Bedrock" or provider == "Sagemaker": printed.append((prompt, x["response"])) elif provider == "anthropic": blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"] @@ -74,7 +74,7 @@ def inspect_history(self, n: int = 1, skip: int = 0): printing_value += prompt text = "" - if provider == "cohere": + if provider == "cohere" or provider == "Bedrock" or provider == "Sagemaker": text = choices elif provider == "openai" or provider == "ollama": text = ' ' + self._get_choice_text(choices[0]).strip() From 98abac2c1c9b9a67c0a8d4cc1d1f52db57cc0df1 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 14:01:14 -0500 Subject: [PATCH 03/15] Revert "feature(modules): added support for aws providers (bedrock and sagemaker) and models (mistral, anthropic, meta)" This reverts commit 34774912729c4f78553fe1f1ad7a3729c89bde92. 
--- docs/api/language_model_clients/aws.md | 61 ----- dsp/modules/__init__.py | 7 +- dsp/modules/aws_models.py | 300 ------------------------- dsp/modules/aws_providers.py | 128 ----------- dsp/modules/bedrock.py | 79 +++++++ dspy/__init__.py | 6 - pyproject.toml | 4 +- tests/modules/test_aws_models.py | 59 ----- 8 files changed, 82 insertions(+), 562 deletions(-) delete mode 100644 docs/api/language_model_clients/aws.md delete mode 100644 dsp/modules/aws_models.py delete mode 100644 dsp/modules/aws_providers.py create mode 100644 dsp/modules/bedrock.py delete mode 100644 tests/modules/test_aws_models.py diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md deleted file mode 100644 index afac264f3d..0000000000 --- a/docs/api/language_model_clients/aws.md +++ /dev/null @@ -1,61 +0,0 @@ ---- -sidebar_position: 9 ---- - -# dsp.AWSMistral, dsp.AWSAnthropic, dsp.AWSMeta - -### Usage - -```python -# Notes: -# 1. Install boto3 to use AWS models. -# 2. Configure your AWS credentials with the AWS CLI before using these models - -# initialize the bedrock aws provider -bedrock = Bedrock(region_name="us-west-2") -# For mixtral on Bedrock -lm = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) -# For haiku on Bedrock -lm = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) -# For llama2 on Bedrock -lm = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) - -# initialize the sagemaker aws provider -sagemaker = Sagemaker(region_name="us-west-2") -# For mistral on Sagemaker -# Note: you need to create a Sagemaker endpoint for the mistral model first -lm = AWSMistral(sagemaker, "", **kwargs) - -``` - -### Constructor - -The constructor initializes the base class `LM` and the `AWSProvider` class. 
- -```python -class AWSMistral(AWSModel): - """Mistral family of models.""" - - def __init__( - self, - aws_provider: AWSProvider, - model: str, - max_context_size: int = 32768, - max_new_tokens: int = 1500, - **kwargs - ) -> None: -``` - -**Parameters:** -- `aws_provider` (AWSProvider): The aws provider to use. One of `Bedrock` or `Sagemaker`. -- `model` (_str_): Mistral AI pretrained models. Defaults to `mistral-medium-latest`. -- `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768. -- `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500. -- `**kwargs`: Additional language model arguments to pass to the API provider. - -### Methods - -Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation. - - -`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. \ No newline at end of file diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index 935f65b403..b7a739f504 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -1,10 +1,6 @@ from .anthropic import Claude -from .aws_models import AWSAnthropic, AWSMeta, AWSMistral, AWSModel - -# Below is obsolete. 
It has been replaced with Bedrock class in dsp/modules/aws_providers.py -# from .bedrock import * -from .aws_providers import Bedrock, Sagemaker from .azure_openai import AzureOpenAI +from .bedrock import * from .cache_utils import * from .clarifai import * from .cohere import * @@ -21,3 +17,4 @@ from .pyserini import * from .sbert import * from .sentence_vectorizer import * + diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py deleted file mode 100644 index 3b93948897..0000000000 --- a/dsp/modules/aws_models.py +++ /dev/null @@ -1,300 +0,0 @@ -"""AWS models for LMs.""" - -from __future__ import annotations - -import json -import logging -from abc import abstractmethod -from typing import Any - -from dsp.modules.aws_providers import AWSProvider, Bedrock, Sagemaker -from dsp.modules.lm import LM - -# Heuristic translating number of chars to tokens -# ~4 chars = 1 token -CHARS2TOKENS: int = 4 - - -class AWSModel(LM): - """This class adds support for an AWS model. - It is an abstract class and should not be instantiated directly. - Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta. - The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker. - Usage Example: - bedrock = Bedrock(region_name="us-west-2") - bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) - bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) - bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) - - sagemaker = Sagemaker(region_name="us-west-2") - sagemaker_model = AWSMistral(sagemaker, "", **kwargs) - """ - - def __init__( - self, - model: str, - max_context_size: int, - max_new_tokens: int, - **kwargs, - ) -> None: - """_summary_. - - Args: - model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint. - max_context_size (int): The maximum context size in tokens. 
- max_new_tokens (int): The maximum number of tokens to be sampled from the LM. - """ - super().__init__(model=model) - self._model_name: str = model - self._max_context_size: int = max_context_size - self._max_new_tokens: int = max_new_tokens - - self.kwargs = { - **self.kwargs, - **kwargs, - } - - @abstractmethod - def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - pass - - @abstractmethod - def _call_model(self, body: str) -> str | list[str]: - """Call model, get generated input without the formatted prompt.""" - - def _estimate_tokens(self, text: str) -> int: - return len(text)/CHARS2TOKENS - - def _extract_input_parameters( - self, - body: dict[Any, Any], - ) -> dict[str, str | float | int]: - return body - - def _format_prompt(self, raw_prompt: str) -> str: - return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" - - def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]: - n, body = self._create_body(formatted_prompt, **kwargs) - json_body = json.dumps(body) - - if n > 1: - llm_out = [self._call_model(json_body) for _ in range(n)] - llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out] - else: - llm_out = self._call_model(json_body) - llm_out = llm_out.replace(formatted_prompt, "") - - self.history.append( - {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}, - ) - return llm_out - - def basic_request(self, prompt, **kwargs) -> str | list[str]: - """Query the endpoint.""" - token_count = self._estimate_tokens(prompt) - if token_count > self._max_context_size: - logging.info("Error - input tokens %s exceeds max context %s", token_count, self._max_context_size) - raise ValueError( - f"Error - input tokens {token_count} exceeds max context {self._max_context_size}", - ) - - formatted_prompt: str = self._format_prompt(prompt) - return self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs) - - def __call__( - self, - prompt: str, - only_completed: bool = 
True, - return_sorted: bool = False, - **kwargs, - ) -> list[str]: - """Query the AWS LLM. - - There is only support for only_completed=True and return_sorted=False - right now. - """ - assert only_completed, "for now" - assert return_sorted is False, "for now" - - generated = self.basic_request(prompt, **kwargs) - return [generated] - - -class AWSMistral(AWSModel): - """Mistral family of models.""" - - def __init__( - self, - aws_provider: AWSProvider, - model: str, - max_context_size: int = 32768, - max_new_tokens: int = 1500, - **kwargs, - ) -> None: - """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" - super().__init__( - model=model, - max_context_size=max_context_size, - max_new_tokens=max_new_tokens, - **kwargs, - ) - self.aws_provider = aws_provider - self.provider = aws_provider.get_provider_name() - - self.kwargs["stop"] = "\n\n---" - - def _format_prompt(self, raw_prompt: str) -> str: - return " [INST] Human: " + raw_prompt + " [/INST] Assistant: " - - def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - base_args: dict[str, Any] = self.kwargs - for k, v in kwargs.items(): - base_args[k] = v - - n, base_args = self.aws_provider.sanitize_kwargs(base_args) - - query_args: dict[str, str | float] = {} - if isinstance(self.aws_provider, Bedrock): - query_args["prompt"] = prompt - elif isinstance(self.aws_provider, Sagemaker): - query_args["parameters"] = base_args - query_args["inputs"] = prompt - else: - raise ValueError("Error - provider not recognized") - - return (n, query_args) - - def _call_model(self, body: str) -> str: - response = self.aws_provider.call_model( - model_id=self._model_name, - body=body, - ) - if isinstance(self.aws_provider, Bedrock): - response_body = json.loads(response["body"].read()) - completion = response_body["outputs"][0]["text"] - elif isinstance(self.aws_provider, Sagemaker): - response_body = json.loads(response["Body"].read()) - completion = 
response_body[0]["generated_text"] - else: - raise ValueError("Error - provider not recognized") - - completion = completion.split(self.kwargs["stop"])[0] - return completion - - -class AWSAnthropic(AWSModel): - """Anthropic family of models.""" - - def __init__( - self, - aws_provider: AWSProvider, - model: str, - max_context_size: int = 200000, - max_new_tokens: int = 1500, - **kwargs, - ) -> None: - """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" - super().__init__( - model=model, - max_context_size=max_context_size, - max_new_tokens=max_new_tokens, - **kwargs, - ) - self.provider = aws_provider - - for k, v in kwargs.items(): - self.kwargs[k] = v - - def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - base_args: dict[str, Any] = self.kwargs - for k, v in kwargs.items(): - base_args[k] = v - - n, query_args = self.provider.sanitize_kwargs(base_args) - - # Anthropic models do not support the following parameters - query_args.pop("frequency_penalty", None) - query_args.pop("num_generations", None) - query_args.pop("presence_penalty", None) - query_args.pop("model", None) - - # we are using the Claude messages API - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - query_args["messages"] = [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt, - }, - ], - }, - ] - query_args["anthropic_version"] = "bedrock-2023-05-31" - return (n, query_args) - - def _call_model(self, body: str) -> str: - response = self.provider.predictor.invoke_model( - modelId=self._model_name, - body=body, - ) - response_body = json.loads(response["body"].read()) - completion = response_body["content"][0]["text"] - return completion - - -class AWSMeta(AWSModel): - """Llama2 family of models.""" - - def __init__( - self, - aws_provider: AWSProvider, - model: str, - max_context_size: int = 4096, - max_new_tokens: int = 1500, - **kwargs, - ) -> 
None: - """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" - super().__init__( - model=model, - max_context_size=max_context_size, - max_new_tokens=max_new_tokens, - **kwargs, - ) - self.provider = aws_provider - - for k, v in kwargs.items(): - self.kwargs[k] = v - - self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens") - - def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - base_args: dict[str, Any] = self.kwargs - for k, v in kwargs.items(): - base_args[k] = v - - n, query_args = self.provider.sanitize_kwargs(base_args) - - # Meta models do not support the following parameters - query_args.pop("frequency_penalty", None) - query_args.pop("num_generations", None) - query_args.pop("presence_penalty", None) - query_args.pop("model", None) - - query_args["prompt"] = prompt - return (n, query_args) - - def _call_model(self, body: str) -> str: - response = self.provider.predictor.invoke_model( - modelId=self._model_name, - body=body, - ) - response_body = json.loads(response["body"].read()) - completion = response_body["generation"] - - stop = "\n\n" - completion = completion.split(stop)[0] - - return completion diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py deleted file mode 100644 index ae3f0d4c3f..0000000000 --- a/dsp/modules/aws_providers.py +++ /dev/null @@ -1,128 +0,0 @@ -"""AWS providers for LMs.""" - -from abc import ABC, abstractmethod -from typing import Any, Optional - - -class AWSProvider(ABC): - """This abstract class adds support for AWS model providers such as Bedrock and SageMaker. - The subclasses such as Bedrock and Sagemaker implement the abstract method _call_model and work in conjunction with the AWSModel classes. 
- Usage Example: - bedrock = Bedrock(region_name="us-west-2") - bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) - bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) - bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) - - sagemaker = Sagemaker(region_name="us-west-2") - sagemaker_model = AWSMistral(sagemaker, "", **kwargs) - """ - - def __init__( - self, - region_name: str, - service_name: str, - profile_name: Optional[str] = None, - batch_n_enabled: bool = True, - ) -> None: - """_summary_. - - Args: - region_name (str, optional): The AWS region where this LM is hosted. - service_name (str): Used in context of invoking the boto3 API. - profile_name (str, optional): boto3 credentials profile. - batch_n_enabled (bool): If False, call the LM N times rather than batching. - """ - try: - import boto3 # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError('Please install boto3 to use AWS models.') from exc - - if profile_name is None: - self.predictor = boto3.client(service_name, region_name=region_name) - else: - self.predictor = boto3.Session(profile_name=profile_name).client( - service_name, - region_name=region_name, - ) - - self.batch_n_enabled = batch_n_enabled - - def get_provider_name(self) -> str: - """Return the provider name.""" - return self.__class__.__name__ - - @abstractmethod - def call_model(self, model_id: str, body: str) -> str: - """Call the model and return the response.""" - - def sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> tuple[int, dict[str, Any]]: - """Ensure that input kwargs can be used by Bedrock or Sagemaker.""" - if "temperature" in query_kwargs: - if query_kwargs["temperature"] > 0.99: - query_kwargs["temperature"] = 0.99 - if query_kwargs["temperature"] < 0.01: - query_kwargs["temperature"] = 0.01 - - if "top_p" in query_kwargs: - if query_kwargs["top_p"] > 0.99: - query_kwargs["top_p"] = 0.99 - 
if query_kwargs["top_p"] < 0.01: - query_kwargs["top_p"] = 0.01 - - n = -1 - if not self.batch_n_enabled: - n = query_kwargs.pop('n', 1) - query_kwargs["num_generations"] = n - - return n, query_kwargs - - -class Bedrock(AWSProvider): - """This class adds support for Bedrock models.""" - - def __init__( - self, - region_name: str, - profile_name: Optional[str] = None, - batch_n_enabled: bool = False, # This has to be setup manually on Bedrock. - ) -> None: - """_summary_. - - Args: - region_name (str, optional): The AWS region where this LM is hosted. - profile_name (str, optional): boto3 credentials profile. - """ - super().__init__(region_name, "bedrock-runtime", profile_name, batch_n_enabled) - - def call_model(self, model_id: str, body: str) -> str: - return self.predictor.invoke_model( - modelId=model_id, - body=body, - accept="application/json", - contentType="application/json", - ) - - -class Sagemaker(AWSProvider): - """This class adds support for Sagemaker models.""" - - def __init__( - self, - region_name: str, - profile_name: Optional[str] = None, - ) -> None: - """_summary_. - - Args: - region_name (str, optional): The AWS region where this LM is hosted. - profile_name (str, optional): boto3 credentials profile. 
- """ - super().__init__(region_name, "runtime.sagemaker", profile_name) - - def call_model(self, model_id: str, body: str) -> str: - return self.predictor.invoke_endpoint( - EndpointName=model_id, - Body=body, - Accept="application/json", - ContentType="application/json", - ) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py new file mode 100644 index 0000000000..252c87fe86 --- /dev/null +++ b/dsp/modules/bedrock.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import json +from typing import Any, Optional + +from dsp.modules.aws_lm import AWSLM + + +class Bedrock(AWSLM): + def __init__( + self, + region_name: str, + model: str, + profile_name: Optional[str] = None, + input_output_ratio: int = 3, + max_new_tokens: int = 1500, + ) -> None: + """Use an AWS Bedrock language model. + NOTE: You must first configure your AWS credentials with the AWS CLI before using this model! + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + model (str, optional): An AWS Bedrock LM name. You can find available models with the AWS CLI as follows: aws bedrock list-foundation-models --query "modelSummaries[*].modelId". + temperature (float, optional): Default temperature for LM. Defaults to 0. + input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3. + max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. 
+ """ + super().__init__( + model=model, + service_name="bedrock-runtime", + region_name=region_name, + profile_name=profile_name, + truncate_long_prompts=False, + input_output_ratio=input_output_ratio, + max_new_tokens=max_new_tokens, + batch_n=True, # Bedrock does not support the `n` parameter + ) + self._validate_model(model) + self.provider = "claude" if "claude" in model.lower() else "bedrock" + + def _validate_model(self, model: str) -> None: + if "claude" not in model.lower(): + raise NotImplementedError("Only claude models are supported as of now") + + def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]: + base_args: dict[str, Any] = { + "max_tokens_to_sample": self._max_new_tokens, + } + for k, v in kwargs.items(): + base_args[k] = v + query_args: dict[str, Any] = self._sanitize_kwargs(base_args) + query_args["prompt"] = prompt + # AWS Bedrock forbids these keys + if "max_tokens" in query_args: + max_tokens: int = query_args["max_tokens"] + input_tokens: int = self._estimate_tokens(prompt) + max_tokens_to_sample: int = max_tokens - input_tokens + del query_args["max_tokens"] + query_args["max_tokens_to_sample"] = max_tokens_to_sample + return query_args + + def _call_model(self, body: str) -> str: + response = self.predictor.invoke_model( + modelId=self._model_name, + body=body, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response["body"].read()) + completion = response_body["completion"] + return completion + + def _extract_input_parameters( + self, body: dict[Any, Any], + ) -> dict[str, str | float | int]: + return body + + def _format_prompt(self, raw_prompt: str) -> str: + return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" diff --git a/dspy/__init__.py b/dspy/__init__.py index 2d5ba4f003..1c6539f59b 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -33,13 +33,7 @@ Together = dsp.Together HFModel = dsp.HFModel OllamaLocal = dsp.OllamaLocal - Bedrock = dsp.Bedrock -Sagemaker 
= dsp.Sagemaker -AWSModel = dsp.AWSModel -AWSMistral = dsp.AWSMistral -AWSAnthropic = dsp.AWSAnthropic -AWSMeta = dsp.AWSMeta configure = settings.configure context = settings.context diff --git a/pyproject.toml b/pyproject.toml index f564865f2f..2dafeab07b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ marqo = ["marqo"] pinecone = ["pinecone-client~=2.2.4"] weaviate = ["weaviate-client~=4.5.4"] milvus = ["pymilvus~=2.3.7"] -aws = ["boto3~=1.34.78"] docs = [ "sphinx>=4.3.0", "furo>=2023.3.27", @@ -59,6 +58,7 @@ docs = [ "autodoc_pydantic", "sphinx-reredirects>=0.1.2", "sphinx-automodapi==0.16.0", + ] dev = ["pytest>=6.2.5"] @@ -100,7 +100,6 @@ qdrant-client = { version = "^1.6.2", optional = true } pinecone-client = { version = "^2.2.4", optional = true } weaviate-client = { version = "^4.5.4", optional = true } pymilvus = { version = "^2.3.6", optional = true } -boto3 = { version = "^1.34.78", optional = true } sphinx = { version = ">=4.3.0", optional = true } furo = { version = ">=2023.3.27", optional = true } docutils = { version = "<0.17", optional = true } @@ -135,7 +134,6 @@ marqo = ["marqo"] pinecone = ["pinecone-client"] weaviate = ["weaviate-client"] milvus = ["pymilvus"] -aws = ["boto3"] postgres = ["psycopg2", "pgvector"] docs = [ "sphinx", diff --git a/tests/modules/test_aws_models.py b/tests/modules/test_aws_models.py deleted file mode 100644 index 89cedcc6f7..0000000000 --- a/tests/modules/test_aws_models.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Tests for AWS models. -Note: Requires configuration of your AWS credentials with the AWS CLI and creating sagemaker endpoints. -TODO: Create mock fixtures for pytest to remove the need for AWS credentials and endpoints. 
-""" - -import dsp -import dspy - -def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM: - """get the language model""" - # extract model vendor and name from model name - # Model path format is / - model_vendor = model_path.split('/')[0] - model_name = model_path.split('/')[1] - - if lm_provider == 'Bedrock': - bedrock = dspy.Bedrock(region_name="us-west-2") - if model_vendor == 'mistral': - return dspy.AWSMistral(bedrock, model_name, **kwargs) - elif model_vendor == 'anthropic': - return dspy.AWSAnthropic(bedrock, model_name, **kwargs) - elif model_vendor == 'meta': - return dspy.AWSMeta(bedrock, model_name, **kwargs) - else: - raise ValueError("Model vendor missing or unsupported: Model path format is /") - elif lm_provider == 'Sagemaker': - sagemaker = dspy.Sagemaker(region_name="us-west-2") - if model_vendor == 'mistral': - return dspy.AWSMistral(sagemaker, model_name, **kwargs) - elif model_vendor == 'meta': - return dspy.AWSMeta(sagemaker, model_name, **kwargs) - else: - raise ValueError("Model vendor missing or unsupported: Model path format is /") - else: - raise ValueError(f"Unsupported model: {model_name}") - -def run_tests(): - """Test the providers and models""" - # Configure your AWS credentials with the AWS CLI before running this script - provider_model_tuples = [ - ('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'), - ('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'), - ('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'), - ('Bedrock', 'meta/meta.llama2-70b-chat-v1'), - # ('Sagemaker', 'mistral/'), # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint - ] - - predict_func = dspy.Predict("question -> answer") - for provider, model_path in provider_model_tuples: - print(f"Provider: {provider}, Model: {model_path}") - lm = get_lm(provider, model_path) - with dspy.context(lm=lm): - question = "What is the capital of France?" 
- answer = predict_func(question=question).answer - print(f"Question: {question}\nAnswer: {answer}\n\n") - - -if __name__ == "__main__": - run_tests() From 702c04fc25f351cfc12432d8276545b66cd70b0a Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 13:38:55 -0500 Subject: [PATCH 04/15] feature(modules): added support for aws providers (bedrock and sagemaker) and models (mistral, anthropic, meta) --- docs/api/language_model_clients/aws.md | 61 +++++ dsp/modules/__init__.py | 7 +- dsp/modules/aws_models.py | 300 +++++++++++++++++++++++++ dsp/modules/aws_providers.py | 128 +++++++++++ dsp/modules/bedrock.py | 79 ------- dspy/__init__.py | 6 + pyproject.toml | 4 +- tests/modules/test_aws_models.py | 59 +++++ 8 files changed, 562 insertions(+), 82 deletions(-) create mode 100644 docs/api/language_model_clients/aws.md create mode 100644 dsp/modules/aws_models.py create mode 100644 dsp/modules/aws_providers.py delete mode 100644 dsp/modules/bedrock.py create mode 100644 tests/modules/test_aws_models.py diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md new file mode 100644 index 0000000000..afac264f3d --- /dev/null +++ b/docs/api/language_model_clients/aws.md @@ -0,0 +1,61 @@ +--- +sidebar_position: 9 +--- + +# dsp.AWSMistral, dsp.AWSAnthropic, dsp.AWSMeta + +### Usage + +```python +# Notes: +# 1. Install boto3 to use AWS models. +# 2. 
Configure your AWS credentials with the AWS CLI before using these models + +# initialize the bedrock aws provider +bedrock = Bedrock(region_name="us-west-2") +# For mixtral on Bedrock +lm = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) +# For haiku on Bedrock +lm = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) +# For llama2 on Bedrock +lm = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + +# initialize the sagemaker aws provider +sagemaker = Sagemaker(region_name="us-west-2") +# For mistral on Sagemaker +# Note: you need to create a Sagemaker endpoint for the mistral model first +lm = AWSMistral(sagemaker, "", **kwargs) + +``` + +### Constructor + +The constructor initializes the base class `LM` and the `AWSProvider` class. + +```python +class AWSMistral(AWSModel): + """Mistral family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 32768, + max_new_tokens: int = 1500, + **kwargs + ) -> None: +``` + +**Parameters:** +- `aws_provider` (AWSProvider): The aws provider to use. One of `Bedrock` or `Sagemaker`. +- `model` (_str_): Mistral AI pretrained models. Defaults to `mistral-medium-latest`. +- `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768. +- `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500. +- `**kwargs`: Additional language model arguments to pass to the API provider. + +### Methods + +Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation. + + +`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. 
\ No newline at end of file diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index b7a739f504..935f65b403 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -1,6 +1,10 @@ from .anthropic import Claude +from .aws_models import AWSAnthropic, AWSMeta, AWSMistral, AWSModel + +# Below is obsolete. It has been replaced with Bedrock class in dsp/modules/aws_providers.py +# from .bedrock import * +from .aws_providers import Bedrock, Sagemaker from .azure_openai import AzureOpenAI -from .bedrock import * from .cache_utils import * from .clarifai import * from .cohere import * @@ -17,4 +21,3 @@ from .pyserini import * from .sbert import * from .sentence_vectorizer import * - diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py new file mode 100644 index 0000000000..3b93948897 --- /dev/null +++ b/dsp/modules/aws_models.py @@ -0,0 +1,300 @@ +"""AWS models for LMs.""" + +from __future__ import annotations + +import json +import logging +from abc import abstractmethod +from typing import Any + +from dsp.modules.aws_providers import AWSProvider, Bedrock, Sagemaker +from dsp.modules.lm import LM + +# Heuristic translating number of chars to tokens +# ~4 chars = 1 token +CHARS2TOKENS: int = 4 + + +class AWSModel(LM): + """This class adds support for an AWS model. + It is an abstract class and should not be instantiated directly. + Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta. + The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker. 
+ Usage Example: + bedrock = Bedrock(region_name="us-west-2") + bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + + sagemaker = Sagemaker(region_name="us-west-2") + sagemaker_model = AWSMistral(sagemaker, "", **kwargs) + """ + + def __init__( + self, + model: str, + max_context_size: int, + max_new_tokens: int, + **kwargs, + ) -> None: + """_summary_. + + Args: + model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint. + max_context_size (int): The maximum context size in tokens. + max_new_tokens (int): The maximum number of tokens to be sampled from the LM. + """ + super().__init__(model=model) + self._model_name: str = model + self._max_context_size: int = max_context_size + self._max_new_tokens: int = max_new_tokens + + self.kwargs = { + **self.kwargs, + **kwargs, + } + + @abstractmethod + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + pass + + @abstractmethod + def _call_model(self, body: str) -> str | list[str]: + """Call model, get generated input without the formatted prompt.""" + + def _estimate_tokens(self, text: str) -> int: + return len(text)/CHARS2TOKENS + + def _extract_input_parameters( + self, + body: dict[Any, Any], + ) -> dict[str, str | float | int]: + return body + + def _format_prompt(self, raw_prompt: str) -> str: + return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" + + def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]: + n, body = self._create_body(formatted_prompt, **kwargs) + json_body = json.dumps(body) + + if n > 1: + llm_out = [self._call_model(json_body) for _ in range(n)] + llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out] + else: + llm_out = self._call_model(json_body) + llm_out = llm_out.replace(formatted_prompt, "") + + 
self.history.append( + {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}, + ) + return llm_out + + def basic_request(self, prompt, **kwargs) -> str | list[str]: + """Query the endpoint.""" + token_count = self._estimate_tokens(prompt) + if token_count > self._max_context_size: + logging.info("Error - input tokens %s exceeds max context %s", token_count, self._max_context_size) + raise ValueError( + f"Error - input tokens {token_count} exceeds max context {self._max_context_size}", + ) + + formatted_prompt: str = self._format_prompt(prompt) + return self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs) + + def __call__( + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, + ) -> list[str]: + """Query the AWS LLM. + + There is only support for only_completed=True and return_sorted=False + right now. + """ + assert only_completed, "for now" + assert return_sorted is False, "for now" + + generated = self.basic_request(prompt, **kwargs) + return [generated] + + +class AWSMistral(AWSModel): + """Mistral family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 32768, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() + + self.kwargs["stop"] = "\n\n---" + + def _format_prompt(self, raw_prompt: str) -> str: + return " [INST] Human: " + raw_prompt + " [/INST] Assistant: " + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, base_args = self.aws_provider.sanitize_kwargs(base_args) + + query_args: dict[str, str | float] = 
{} + if isinstance(self.aws_provider, Bedrock): + query_args["prompt"] = prompt + elif isinstance(self.aws_provider, Sagemaker): + query_args["parameters"] = base_args + query_args["inputs"] = prompt + else: + raise ValueError("Error - provider not recognized") + + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.aws_provider.call_model( + model_id=self._model_name, + body=body, + ) + if isinstance(self.aws_provider, Bedrock): + response_body = json.loads(response["body"].read()) + completion = response_body["outputs"][0]["text"] + elif isinstance(self.aws_provider, Sagemaker): + response_body = json.loads(response["Body"].read()) + completion = response_body[0]["generated_text"] + else: + raise ValueError("Error - provider not recognized") + + completion = completion.split(self.kwargs["stop"])[0] + return completion + + +class AWSAnthropic(AWSModel): + """Anthropic family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 200000, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.provider = aws_provider + + for k, v in kwargs.items(): + self.kwargs[k] = v + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, query_args = self.provider.sanitize_kwargs(base_args) + + # Anthropic models do not support the following parameters + query_args.pop("frequency_penalty", None) + query_args.pop("num_generations", None) + query_args.pop("presence_penalty", None) + query_args.pop("model", None) + + # we are using the Claude messages API + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html 
+ query_args["messages"] = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + ], + }, + ] + query_args["anthropic_version"] = "bedrock-2023-05-31" + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.provider.predictor.invoke_model( + modelId=self._model_name, + body=body, + ) + response_body = json.loads(response["body"].read()) + completion = response_body["content"][0]["text"] + return completion + + +class AWSMeta(AWSModel): + """Llama2 family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 4096, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.provider = aws_provider + + for k, v in kwargs.items(): + self.kwargs[k] = v + + self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens") + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, query_args = self.provider.sanitize_kwargs(base_args) + + # Meta models do not support the following parameters + query_args.pop("frequency_penalty", None) + query_args.pop("num_generations", None) + query_args.pop("presence_penalty", None) + query_args.pop("model", None) + + query_args["prompt"] = prompt + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.provider.predictor.invoke_model( + modelId=self._model_name, + body=body, + ) + response_body = json.loads(response["body"].read()) + completion = response_body["generation"] + + stop = "\n\n" + completion = completion.split(stop)[0] + + return completion diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py new file mode 100644 index 
0000000000..ae3f0d4c3f --- /dev/null +++ b/dsp/modules/aws_providers.py @@ -0,0 +1,128 @@ +"""AWS providers for LMs.""" + +from abc import ABC, abstractmethod +from typing import Any, Optional + + +class AWSProvider(ABC): + """This abstract class adds support for AWS model providers such as Bedrock and SageMaker. + The subclasses such as Bedrock and Sagemaker implement the abstract method _call_model and work in conjunction with the AWSModel classes. + Usage Example: + bedrock = Bedrock(region_name="us-west-2") + bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + + sagemaker = Sagemaker(region_name="us-west-2") + sagemaker_model = AWSMistral(sagemaker, "", **kwargs) + """ + + def __init__( + self, + region_name: str, + service_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = True, + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + service_name (str): Used in context of invoking the boto3 API. + profile_name (str, optional): boto3 credentials profile. + batch_n_enabled (bool): If False, call the LM N times rather than batching. 
+ """ + try: + import boto3 # pylint: disable=import-outside-toplevel + except ImportError as exc: + raise ImportError('Please install boto3 to use AWS models.') from exc + + if profile_name is None: + self.predictor = boto3.client(service_name, region_name=region_name) + else: + self.predictor = boto3.Session(profile_name=profile_name).client( + service_name, + region_name=region_name, + ) + + self.batch_n_enabled = batch_n_enabled + + def get_provider_name(self) -> str: + """Return the provider name.""" + return self.__class__.__name__ + + @abstractmethod + def call_model(self, model_id: str, body: str) -> str: + """Call the model and return the response.""" + + def sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> tuple[int, dict[str, Any]]: + """Ensure that input kwargs can be used by Bedrock or Sagemaker.""" + if "temperature" in query_kwargs: + if query_kwargs["temperature"] > 0.99: + query_kwargs["temperature"] = 0.99 + if query_kwargs["temperature"] < 0.01: + query_kwargs["temperature"] = 0.01 + + if "top_p" in query_kwargs: + if query_kwargs["top_p"] > 0.99: + query_kwargs["top_p"] = 0.99 + if query_kwargs["top_p"] < 0.01: + query_kwargs["top_p"] = 0.01 + + n = -1 + if not self.batch_n_enabled: + n = query_kwargs.pop('n', 1) + query_kwargs["num_generations"] = n + + return n, query_kwargs + + +class Bedrock(AWSProvider): + """This class adds support for Bedrock models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = False, # This has to be setup manually on Bedrock. + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + profile_name (str, optional): boto3 credentials profile. 
+ """ + super().__init__(region_name, "bedrock-runtime", profile_name, batch_n_enabled) + + def call_model(self, model_id: str, body: str) -> str: + return self.predictor.invoke_model( + modelId=model_id, + body=body, + accept="application/json", + contentType="application/json", + ) + + +class Sagemaker(AWSProvider): + """This class adds support for Sagemaker models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + profile_name (str, optional): boto3 credentials profile. + """ + super().__init__(region_name, "runtime.sagemaker", profile_name) + + def call_model(self, model_id: str, body: str) -> str: + return self.predictor.invoke_endpoint( + EndpointName=model_id, + Body=body, + Accept="application/json", + ContentType="application/json", + ) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py deleted file mode 100644 index 252c87fe86..0000000000 --- a/dsp/modules/bedrock.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, Optional - -from dsp.modules.aws_lm import AWSLM - - -class Bedrock(AWSLM): - def __init__( - self, - region_name: str, - model: str, - profile_name: Optional[str] = None, - input_output_ratio: int = 3, - max_new_tokens: int = 1500, - ) -> None: - """Use an AWS Bedrock language model. - NOTE: You must first configure your AWS credentials with the AWS CLI before using this model! - - Args: - region_name (str, optional): The AWS region where this LM is hosted. - model (str, optional): An AWS Bedrock LM name. You can find available models with the AWS CLI as follows: aws bedrock list-foundation-models --query "modelSummaries[*].modelId". - temperature (float, optional): Default temperature for LM. Defaults to 0. - input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. 
Defaults to 3. - max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. - """ - super().__init__( - model=model, - service_name="bedrock-runtime", - region_name=region_name, - profile_name=profile_name, - truncate_long_prompts=False, - input_output_ratio=input_output_ratio, - max_new_tokens=max_new_tokens, - batch_n=True, # Bedrock does not support the `n` parameter - ) - self._validate_model(model) - self.provider = "claude" if "claude" in model.lower() else "bedrock" - - def _validate_model(self, model: str) -> None: - if "claude" not in model.lower(): - raise NotImplementedError("Only claude models are supported as of now") - - def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]: - base_args: dict[str, Any] = { - "max_tokens_to_sample": self._max_new_tokens, - } - for k, v in kwargs.items(): - base_args[k] = v - query_args: dict[str, Any] = self._sanitize_kwargs(base_args) - query_args["prompt"] = prompt - # AWS Bedrock forbids these keys - if "max_tokens" in query_args: - max_tokens: int = query_args["max_tokens"] - input_tokens: int = self._estimate_tokens(prompt) - max_tokens_to_sample: int = max_tokens - input_tokens - del query_args["max_tokens"] - query_args["max_tokens_to_sample"] = max_tokens_to_sample - return query_args - - def _call_model(self, body: str) -> str: - response = self.predictor.invoke_model( - modelId=self._model_name, - body=body, - accept="application/json", - contentType="application/json", - ) - response_body = json.loads(response["body"].read()) - completion = response_body["completion"] - return completion - - def _extract_input_parameters( - self, body: dict[Any, Any], - ) -> dict[str, str | float | int]: - return body - - def _format_prompt(self, raw_prompt: str) -> str: - return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" diff --git a/dspy/__init__.py b/dspy/__init__.py index 1c6539f59b..2d5ba4f003 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -33,7 +33,13 @@ 
Together = dsp.Together HFModel = dsp.HFModel OllamaLocal = dsp.OllamaLocal + Bedrock = dsp.Bedrock +Sagemaker = dsp.Sagemaker +AWSModel = dsp.AWSModel +AWSMistral = dsp.AWSMistral +AWSAnthropic = dsp.AWSAnthropic +AWSMeta = dsp.AWSMeta configure = settings.configure context = settings.context diff --git a/pyproject.toml b/pyproject.toml index 2dafeab07b..f564865f2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ marqo = ["marqo"] pinecone = ["pinecone-client~=2.2.4"] weaviate = ["weaviate-client~=4.5.4"] milvus = ["pymilvus~=2.3.7"] +aws = ["boto3~=1.34.78"] docs = [ "sphinx>=4.3.0", "furo>=2023.3.27", @@ -58,7 +59,6 @@ docs = [ "autodoc_pydantic", "sphinx-reredirects>=0.1.2", "sphinx-automodapi==0.16.0", - ] dev = ["pytest>=6.2.5"] @@ -100,6 +100,7 @@ qdrant-client = { version = "^1.6.2", optional = true } pinecone-client = { version = "^2.2.4", optional = true } weaviate-client = { version = "^4.5.4", optional = true } pymilvus = { version = "^2.3.6", optional = true } +boto3 = { version = "^1.34.78", optional = true } sphinx = { version = ">=4.3.0", optional = true } furo = { version = ">=2023.3.27", optional = true } docutils = { version = "<0.17", optional = true } @@ -134,6 +135,7 @@ marqo = ["marqo"] pinecone = ["pinecone-client"] weaviate = ["weaviate-client"] milvus = ["pymilvus"] +aws = ["boto3"] postgres = ["psycopg2", "pgvector"] docs = [ "sphinx", diff --git a/tests/modules/test_aws_models.py b/tests/modules/test_aws_models.py new file mode 100644 index 0000000000..89cedcc6f7 --- /dev/null +++ b/tests/modules/test_aws_models.py @@ -0,0 +1,59 @@ +"""Tests for AWS models. +Note: Requires configuration of your AWS credentials with the AWS CLI and creating sagemaker endpoints. +TODO: Create mock fixtures for pytest to remove the need for AWS credentials and endpoints. 
+""" + +import dsp +import dspy + +def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM: + """get the language model""" + # extract model vendor and name from model name + # Model path format is / + model_vendor = model_path.split('/')[0] + model_name = model_path.split('/')[1] + + if lm_provider == 'Bedrock': + bedrock = dspy.Bedrock(region_name="us-west-2") + if model_vendor == 'mistral': + return dspy.AWSMistral(bedrock, model_name, **kwargs) + elif model_vendor == 'anthropic': + return dspy.AWSAnthropic(bedrock, model_name, **kwargs) + elif model_vendor == 'meta': + return dspy.AWSMeta(bedrock, model_name, **kwargs) + else: + raise ValueError("Model vendor missing or unsupported: Model path format is /") + elif lm_provider == 'Sagemaker': + sagemaker = dspy.Sagemaker(region_name="us-west-2") + if model_vendor == 'mistral': + return dspy.AWSMistral(sagemaker, model_name, **kwargs) + elif model_vendor == 'meta': + return dspy.AWSMeta(sagemaker, model_name, **kwargs) + else: + raise ValueError("Model vendor missing or unsupported: Model path format is /") + else: + raise ValueError(f"Unsupported model: {model_name}") + +def run_tests(): + """Test the providers and models""" + # Configure your AWS credentials with the AWS CLI before running this script + provider_model_tuples = [ + ('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'), + ('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'), + ('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'), + ('Bedrock', 'meta/meta.llama2-70b-chat-v1'), + # ('Sagemaker', 'mistral/'), # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint + ] + + predict_func = dspy.Predict("question -> answer") + for provider, model_path in provider_model_tuples: + print(f"Provider: {provider}, Model: {model_path}") + lm = get_lm(provider, model_path) + with dspy.context(lm=lm): + question = "What is the capital of France?" 
+ answer = predict_func(question=question).answer + print(f"Question: {question}\nAnswer: {answer}\n\n") + + +if __name__ == "__main__": + run_tests() From f667bf27a94300054c55fc26af945f3d6dbbc27f Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 14:36:49 -0500 Subject: [PATCH 05/15] feature(module): update lock file since we added boto3 --- poetry.lock | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 62b3b37ad3..d92a6a00e4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiohttp" @@ -530,6 +530,47 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "boto3" +version = "1.34.81" +description = "The AWS SDK for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.81-py3-none-any.whl", hash = "sha256:18224d206a8a775bcaa562d22ed3d07854934699190e12b52fcde87aac76a80e"}, + {file = "boto3-1.34.81.tar.gz", hash = "sha256:004dad209d37b3d2df88f41da13b7ad702a751904a335fac095897ff7a19f82b"}, +] + +[package.dependencies] +botocore = ">=1.34.81,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.81" +description = "Low-level, data-driven core of boto 3." 
+optional = true +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.81-py3-none-any.whl", hash = "sha256:85f6fd7c5715eeef7a236c50947de00f57d72e7439daed1125491014b70fab01"}, + {file = "botocore-1.34.81.tar.gz", hash = "sha256:f79bf122566cc1f09d71cc9ac9fcf52d47ba48b761cbc3f064017b36a3c40eb8"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.19.19)"] + [[package]] name = "cachetools" version = "5.3.3" @@ -2073,6 +2114,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "joblib" version = "1.3.2" @@ -4330,7 +4382,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = 
"PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4829,6 +4880,23 @@ files = [ {file = "ruff-0.3.4.tar.gz", hash = "sha256:f0f4484c6541a99862b693e13a151435a279b271cff20e37101116a21e2a1ad1"}, ] +[[package]] +name = "s3transfer" +version = "0.10.1" +description = "An Amazon S3 Transfer Manager" +optional = true +python-versions = ">= 3.8" +files = [ + {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, + {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "safetensors" version = "0.4.2" @@ -5904,6 +5972,22 @@ files = [ {file = "ujson-5.9.0.tar.gz", hash = "sha256:89cc92e73d5501b8a7f48575eeb14ad27156ad092c2e9fc7e3cf949f07e75532"}, ] +[[package]] +name = "urllib3" +version = "1.26.18" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, + {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.1" @@ -6633,6 +6717,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] +aws = ["boto3"] chromadb = ["chromadb"] docs = ["autodoc_pydantic", "docutils", "furo", "m2r2", "myst-nb", "myst-parser", "sphinx", "sphinx-autobuild", "sphinx-automodapi", "sphinx-reredirects", "sphinx_rtd_theme"] marqo = ["marqo"] @@ -6645,4 +6730,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "582b069b4377e2b568dfdc2b43187f7c5cb3af12e1a4d6fc3a6fa7dde7edba10" +content-hash = "3988e0bedd832c87fda15828d8c6f08b2c3a9e75a9bca6d4201c5b8bdf5e3c9e" From 51135c0c683f59a2b472d0cde15399d9beb5cf5e Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 14:37:47 -0500 Subject: [PATCH 06/15] feature(modules): fixed bugs and tested inspect_history --- dsp/modules/aws_models.py | 14 ++++++++------ tests/modules/test_aws_models.py | 5 ++++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py index 3b93948897..ec50332972 
100644 --- a/dsp/modules/aws_models.py +++ b/dsp/modules/aws_models.py @@ -201,7 +201,8 @@ def __init__( max_new_tokens=max_new_tokens, **kwargs, ) - self.provider = aws_provider + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() for k, v in kwargs.items(): self.kwargs[k] = v @@ -211,7 +212,7 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa for k, v in kwargs.items(): base_args[k] = v - n, query_args = self.provider.sanitize_kwargs(base_args) + n, query_args = self.aws_provider.sanitize_kwargs(base_args) # Anthropic models do not support the following parameters query_args.pop("frequency_penalty", None) @@ -236,7 +237,7 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa return (n, query_args) def _call_model(self, body: str) -> str: - response = self.provider.predictor.invoke_model( + response = self.aws_provider.predictor.invoke_model( modelId=self._model_name, body=body, ) @@ -263,7 +264,8 @@ def __init__( max_new_tokens=max_new_tokens, **kwargs, ) - self.provider = aws_provider + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() for k, v in kwargs.items(): self.kwargs[k] = v @@ -275,7 +277,7 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa for k, v in kwargs.items(): base_args[k] = v - n, query_args = self.provider.sanitize_kwargs(base_args) + n, query_args = self.aws_provider.sanitize_kwargs(base_args) # Meta models do not support the following parameters query_args.pop("frequency_penalty", None) @@ -287,7 +289,7 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa return (n, query_args) def _call_model(self, body: str) -> str: - response = self.provider.predictor.invoke_model( + response = self.aws_provider.predictor.invoke_model( modelId=self._model_name, body=body, ) diff --git a/tests/modules/test_aws_models.py b/tests/modules/test_aws_models.py index 
89cedcc6f7..fd0794f7cd 100644 --- a/tests/modules/test_aws_models.py +++ b/tests/modules/test_aws_models.py @@ -52,7 +52,10 @@ def run_tests(): with dspy.context(lm=lm): question = "What is the capital of France?" answer = predict_func(question=question).answer - print(f"Question: {question}\nAnswer: {answer}\n\n") + print(f"Question: {question}\nAnswer: {answer}") + print("---------------------------------") + lm.inspect_history() + print("---------------------------------\n") if __name__ == "__main__": From f4dc4955733cbaddd5f2aa05e6859009c8eaa83c Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Tue, 9 Apr 2024 20:58:30 -0500 Subject: [PATCH 07/15] feature(modules): deleteing obsolete code --- dsp/modules/aws_lm.py | 186 ------------------------------------------ 1 file changed, 186 deletions(-) delete mode 100644 dsp/modules/aws_lm.py diff --git a/dsp/modules/aws_lm.py b/dsp/modules/aws_lm.py deleted file mode 100644 index 366d0d1ed7..0000000000 --- a/dsp/modules/aws_lm.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -A generalized AWS LLM. -""" - -from __future__ import annotations - -import json -import logging -from abc import abstractmethod -from typing import Any, Literal, Optional - -from dsp.modules.lm import LM - -# Heuristic translating number of chars to tokens -# ~4 chars = 1 token -CHARS2TOKENS: int = 4 - - -class AWSLM(LM): - """ - This class adds support for an AWS model - """ - - def __init__( - self, - model: str, - region_name: str, - service_name: str, - max_new_tokens: int, - profile_name: Optional[str] = None, - truncate_long_prompts: bool = False, - input_output_ratio: int = 3, - batch_n: bool = True, - ) -> None: - """_summary_ - - Args: - - service_name (str): Used in context of invoking the boto3 API. - region_name (str, optional): The AWS region where this LM is hosted. - model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint. - max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. 
- input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3. - temperature (float, optional): _description_. Defaults to 0.0. - truncate_long_prompts (bool, optional): If True, remove extremely long inputs to context. Defaults to False. - batch_n (bool, False): If False, call the LM N times rather than batching. Not all AWS models support the n parameter. - """ - super().__init__(model=model) - # AWS doesn't have an equivalent of max_tokens so let's clarify - # that the expected input is going to be about 2x as long as the output - self.kwargs["max_tokens"] = max_new_tokens * input_output_ratio - self._max_new_tokens: int = max_new_tokens - self._model_name: str = model - self._truncate_long_prompt_prompts: bool = truncate_long_prompts - self._batch_n: bool = batch_n - - import boto3 - - if profile_name is None: - self.predictor = boto3.client(service_name, region_name=region_name) - else: - self.predictor = boto3.Session(profile_name=profile_name).client( - service_name, region_name=region_name, - ) - - @abstractmethod - def _create_body(self, prompt: str, **kwargs): - pass - - def _sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> dict[str, Any]: - """Ensure that input kwargs can be used by Bedrock or Sagemaker.""" - base_args: dict[str, Any] = {"temperature": self.kwargs["temperature"]} - - for k, v in base_args.items(): - if k not in query_kwargs: - query_kwargs[k] = v - if query_kwargs["temperature"] > 1.0: - query_kwargs["temperature"] = 0.99 - if query_kwargs["temperature"] < 0.01: - query_kwargs["temperature"] = 0.01 - - return query_kwargs - - @abstractmethod - def _call_model(self, body: str) -> str | list[str]: - """Call model, get generated input without the formatted prompt""" - pass - - @abstractmethod - def _extract_input_parameters( - self, body: dict[Any, Any], - ) -> dict[str, str | float | int]: - pass - - def _simple_api_call(self, formatted_prompt: str, **kwargs) -> 
str | list[str]: - body = self._create_body(formatted_prompt, **kwargs) - json_body = json.dumps(body) - llm_out: str | list[str] = self._call_model(json_body) - if isinstance(llm_out, str): - llm_out = llm_out.replace(formatted_prompt, "") - else: - llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out] - self.history.append( - {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}, - ) - return llm_out - - def basic_request(self, prompt, **kwargs) -> str | list[str]: - """Query the endpoint.""" - - # Remove any texts that are too long - formatted_prompt: str - if self._truncate_long_prompt_prompts: - truncated_prompt: str = self._truncate_prompt(prompt) - formatted_prompt = self._format_prompt(truncated_prompt) - else: - formatted_prompt = self._format_prompt(prompt) - - llm_out: str | list[str] - if "n" in kwargs.keys(): - if self._batch_n: - llm_out = self._simple_api_call( - formatted_prompt=formatted_prompt, **kwargs, - ) - else: - del kwargs["n"] - llm_out = [] - for _ in range(0, kwargs["n"]): - generated: str | list[str] = self._simple_api_call( - formatted_prompt=formatted_prompt, **kwargs, - ) - if isinstance(generated, str): - llm_out.append(generated) - else: - raise TypeError("Error, list type was returned from LM call") - else: - llm_out = self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs) - - return llm_out - - def _estimate_tokens(self, text: str) -> int: - return len(text) * CHARS2TOKENS - - @abstractmethod - def _format_prompt(self, raw_prompt: str) -> str: - pass - - def _truncate_prompt( - self, - input_text: str, - remove_beginning_or_ending: Literal["beginning", "ending"] = "beginning", - max_input_tokens: int = 2500, - ) -> str: - """Reformat inputs such that they do not overflow context size limitation.""" - token_count = self._estimate_tokens(input_text) - if token_count > self.kwargs["max_tokens"]: - logging.info("Excessive prompt found in llm input") - logging.info("Truncating texts to 
avoid error") - max_chars: int = CHARS2TOKENS * max_input_tokens - truncated_text: str - if remove_beginning_or_ending == "ending": - truncated_text = input_text[0:max_chars] - else: - truncated_text = input_text[-max_chars:] - return truncated_text - return input_text - - def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, - ) -> list[str]: - """ - Query the AWS LLM. - - There is only support for only_completed=True and return_sorted=False - right now. - """ - if not only_completed: - raise NotImplementedError("Error, only_completed not yet supported!") - if return_sorted: - raise NotImplementedError("Error, return_sorted not yet supported!") - generated = self.basic_request(prompt, **kwargs) - return [generated] From d2707eb7d019710ba10b75176be7186af2ac9f8a Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Thu, 11 Apr 2024 10:48:57 -0500 Subject: [PATCH 08/15] docs(aws.md): changed lm calls to dspy.Bedrock etc. --- docs/api/language_model_clients/aws.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md index afac264f3d..da06e8e8e4 100644 --- a/docs/api/language_model_clients/aws.md +++ b/docs/api/language_model_clients/aws.md @@ -12,19 +12,19 @@ sidebar_position: 9 # 2. 
Configure your AWS credentials with the AWS CLI before using these models # initialize the bedrock aws provider -bedrock = Bedrock(region_name="us-west-2") +bedrock = dspy.Bedrock(region_name="us-west-2") # For mixtral on Bedrock -lm = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) +lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) # For haiku on Bedrock -lm = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) +lm = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) # For llama2 on Bedrock -lm = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) +lm = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) # initialize the sagemaker aws provider -sagemaker = Sagemaker(region_name="us-west-2") +sagemaker = dspy.Sagemaker(region_name="us-west-2") # For mistral on Sagemaker # Note: you need to create a Sagemaker endpoint for the mistral model first -lm = AWSMistral(sagemaker, "", **kwargs) +lm = dspy.AWSMistral(sagemaker, "", **kwargs) ``` From f33eeff3389c9c8c0772ce9c6e5adf4d3c2f3b2b Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Thu, 11 Apr 2024 10:50:48 -0500 Subject: [PATCH 09/15] docs(aws.md): fixed typo --- docs/api/language_model_clients/aws.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md index da06e8e8e4..0b9386f712 100644 --- a/docs/api/language_model_clients/aws.md +++ b/docs/api/language_model_clients/aws.md @@ -30,7 +30,7 @@ lm = dspy.AWSMistral(sagemaker, "", **kwargs) ### Constructor -The constructor initializes the base class `LM` and the `AWSProvider` class. +The constructor initializes the base class `LM` and the `AWSModel` class. 
```python class AWSMistral(AWSModel): From fbe22a233efab5a7839df0b02981ad0501e887f0 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Thu, 11 Apr 2024 11:23:35 -0500 Subject: [PATCH 10/15] docs(aws.md): added commentary on AWSModel methods --- docs/api/language_model_clients/aws.md | 33 ++++++++++++++++++++++---- dsp/modules/aws_models.py | 12 +++++----- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md index 0b9386f712..4f3c4b0b65 100644 --- a/docs/api/language_model_clients/aws.md +++ b/docs/api/language_model_clients/aws.md @@ -30,7 +30,7 @@ lm = dspy.AWSMistral(sagemaker, "", **kwargs) ### Constructor -The constructor initializes the base class `LM` and the `AWSModel` class. +The `AWSMistral` constructor initializes the base class `AWSModel` which itself inherits from the `LM` class. ```python class AWSMistral(AWSModel): @@ -47,15 +47,40 @@ class AWSMistral(AWSModel): ``` **Parameters:** -- `aws_provider` (AWSProvider): The aws provider to use. One of `Bedrock` or `Sagemaker`. -- `model` (_str_): Mistral AI pretrained models. Defaults to `mistral-medium-latest`. +- `aws_provider` (AWSProvider): The aws provider to use. One of `dspy.Bedrock` or `dspy.Sagemaker`. +- `model` (_str_): Mistral AI pretrained models. For Bedrock, this is the Model ID in https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns. For Sagemaker, this is the endpoint name. - `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768. - `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500. - `**kwargs`: Additional language model arguments to pass to the API provider. ### Methods -Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation. 
+```python +def _format_prompt(self, raw_prompt: str) -> str: +``` +This function formats the prompt for the model. Refer to the model card for the specific formatting required. + +
+ +```python +def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: +``` +This function creates the body of the request to the model. It takes the prompt and any additional keyword arguments and returns a tuple containing the number of completions to generate (`n`) and a dictionary of request parameters, including the prompt, that is used as the body of the request. + +
+ +```python +def _call_model(self, body: str) -> str: +``` +This function calls the model using the provider `call_model()` function and extracts the generated text (completion) from the provider-specific response. + +
+ +The above model-specific methods are called by the `AWSModel.basic_request()` method, which is the main method for querying the model. This method takes the prompt and any additional keyword arguments and calls the `AWSModel._simple_api_call()` method, which then delegates to the model-specific `_create_body()` and `_call_model()` methods to create the body of the request, call the model, and extract the generated text. + + +Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation for information on the `LM` base class functionality. +
`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. \ No newline at end of file diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py index ec50332972..ec931e1561 100644 --- a/dsp/modules/aws_models.py +++ b/dsp/modules/aws_models.py @@ -21,13 +21,13 @@ class AWSModel(LM): Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta. The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker. Usage Example: - bedrock = Bedrock(region_name="us-west-2") - bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) - bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) - bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + bedrock = dspy.Bedrock(region_name="us-west-2") + bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) - sagemaker = Sagemaker(region_name="us-west-2") - sagemaker_model = AWSMistral(sagemaker, "", **kwargs) + sagemaker = dspy.Sagemaker(region_name="us-west-2") + sagemaker_model = dspy.AWSMistral(sagemaker, "", **kwargs) """ def __init__( From 86800688ed3bedbfcf4c9bc06598ffef10530d65 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Thu, 11 Apr 2024 11:33:29 -0500 Subject: [PATCH 11/15] feature(aws_models): anthropic_version param should be set in constructor and is for bedrock only --- dsp/modules/aws_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py index ec931e1561..94b07cad43 100644 --- a/dsp/modules/aws_models.py +++ b/dsp/modules/aws_models.py @@ -204,6 +204,9 @@ def __init__( self.aws_provider = aws_provider self.provider = 
aws_provider.get_provider_name() + if isinstance(self.aws_provider, Bedrock): + self.kwargs["anthropic_version"] = "bedrock-2023-05-31" + for k, v in kwargs.items(): self.kwargs[k] = v @@ -233,7 +236,6 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa ], }, ] - query_args["anthropic_version"] = "bedrock-2023-05-31" return (n, query_args) def _call_model(self, body: str) -> str: From 1b45f721c2f24d65e5922177f75cd5d9f66a7ed1 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Thu, 11 Apr 2024 11:37:15 -0500 Subject: [PATCH 12/15] docs(aws): removed pylint disable and tweaked import error message --- dsp/modules/aws_providers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py index ae3f0d4c3f..337ef6cd2a 100644 --- a/dsp/modules/aws_providers.py +++ b/dsp/modules/aws_providers.py @@ -33,9 +33,9 @@ def __init__( batch_n_enabled (bool): If False, call the LM N times rather than batching. 
""" try: - import boto3 # pylint: disable=import-outside-toplevel + import boto3 except ImportError as exc: - raise ImportError('Please install boto3 to use AWS models.') from exc + raise ImportError('pip install boto3 to use AWS models.') from exc if profile_name is None: self.predictor = boto3.client(service_name, region_name=region_name) From 87540dac2c09888dab2367389bdb153b4d1d4fd4 Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> Date: Fri, 12 Apr 2024 09:59:22 -0700 Subject: [PATCH 13/15] Update aws.md --- docs/api/language_model_clients/aws.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws.md index 4f3c4b0b65..143c484460 100644 --- a/docs/api/language_model_clients/aws.md +++ b/docs/api/language_model_clients/aws.md @@ -2,7 +2,7 @@ sidebar_position: 9 --- -# dsp.AWSMistral, dsp.AWSAnthropic, dsp.AWSMeta +# dspy.AWSMistral, dspy.AWSAnthropic, dspy.AWSMeta ### Usage @@ -83,4 +83,4 @@ Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients
-`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. \ No newline at end of file +`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. From 4326033802bf44a840d2f58ece5107c5ad91878c Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:04:26 -0700 Subject: [PATCH 14/15] Update aws_providers.py --- dsp/modules/aws_providers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py index 337ef6cd2a..536867818b 100644 --- a/dsp/modules/aws_providers.py +++ b/dsp/modules/aws_providers.py @@ -8,13 +8,13 @@ class AWSProvider(ABC): """This abstract class adds support for AWS model providers such as Bedrock and SageMaker. The subclasses such as Bedrock and Sagemaker implement the abstract method _call_model and work in conjunction with the AWSModel classes. Usage Example: - bedrock = Bedrock(region_name="us-west-2") - bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) - bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) - bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + bedrock = dspy.Bedrock(region_name="us-west-2") + bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) - sagemaker = Sagemaker(region_name="us-west-2") - sagemaker_model = AWSMistral(sagemaker, "", **kwargs) + sagemaker = dspy.Sagemaker(region_name="us-west-2") + sagemaker_model = dspy.AWSMistral(sagemaker, "", **kwargs) """ def __init__( From b65e17f02045fb7243e60c657015903a92b694e0 Mon Sep 17 00:00:00 2001 From: Dhar Rawal Date: Fri, 12 Apr 2024 13:10:05 -0500 Subject: [PATCH 15/15] feature(docs): Documented the 
aws_provider class. Renamed the aws md files. Updated docs in aws_providers.py --- .../{aws.md => aws_models.md} | 0 .../language_model_clients/aws_providers.md | 53 +++++++++++++++++++ dsp/modules/aws_providers.py | 2 +- 3 files changed, 54 insertions(+), 1 deletion(-) rename docs/api/language_model_clients/{aws.md => aws_models.md} (100%) create mode 100644 docs/api/language_model_clients/aws_providers.md diff --git a/docs/api/language_model_clients/aws.md b/docs/api/language_model_clients/aws_models.md similarity index 100% rename from docs/api/language_model_clients/aws.md rename to docs/api/language_model_clients/aws_models.md diff --git a/docs/api/language_model_clients/aws_providers.md b/docs/api/language_model_clients/aws_providers.md new file mode 100644 index 0000000000..46950db9e0 --- /dev/null +++ b/docs/api/language_model_clients/aws_providers.md @@ -0,0 +1,53 @@ +--- +sidebar_position: 9 +--- + +# dspy.Bedrock, dspy.Sagemaker + +### Usage + +The `AWSProvider` class is the base class for the AWS providers - `dspy.Bedrock` and `dspy.Sagemaker`. An instance of one of these providers is passed to the constructor when creating an instance of an AWS model class (e.g., `dspy.AWSMistral`) that is ultimately used to query the model. + +```python +# Notes: +# 1. Install boto3 to use AWS models. +# 2. Configure your AWS credentials with the AWS CLI before using these models + +# initialize the bedrock aws provider +bedrock = dspy.Bedrock(region_name="us-west-2") + +# initialize the sagemaker aws provider +sagemaker = dspy.Sagemaker(region_name="us-west-2") +``` + +### Constructor + +The `Bedrock` constructor initializes the base class `AWSProvider`. + +```python +class Bedrock(AWSProvider): + """This class adds support for Bedrock models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = False, # This has to be set up manually on Bedrock.
+ ) -> None: +``` + +**Parameters:** +- `region_name` (str): The AWS region where this LM is hosted. +- `profile_name` (str, optional): Name of the boto3 credentials profile to use. Defaults to None. +- `batch_n_enabled` (bool, optional): If False, call the LM N times rather than batching. Defaults to False. + +### Methods + +```python +def call_model(self, model_id: str, body: str) -> str: +``` +This method implements the actual invocation of the model on AWS using the boto3 client. + +
+ +`Sagemaker` works exactly the same as `Bedrock`. \ No newline at end of file diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py index 337ef6cd2a..9bf1f936e2 100644 --- a/dsp/modules/aws_providers.py +++ b/dsp/modules/aws_providers.py @@ -89,7 +89,7 @@ def __init__( """_summary_. Args: - region_name (str, optional): The AWS region where this LM is hosted. + region_name (str): The AWS region where this LM is hosted. profile_name (str, optional): boto3 credentials profile. """ super().__init__(region_name, "bedrock-runtime", profile_name, batch_n_enabled)