diff --git a/docs/api/language_model_clients/aws_models.md b/docs/api/language_model_clients/aws_models.md new file mode 100644 index 0000000000..143c484460 --- /dev/null +++ b/docs/api/language_model_clients/aws_models.md @@ -0,0 +1,86 @@ +--- +sidebar_position: 9 +--- + +# dspy.AWSMistral, dspy.AWSAnthropic, dspy.AWSMeta + +### Usage + +```python +# Notes: +# 1. Install boto3 to use AWS models. +# 2. Configure your AWS credentials with the AWS CLI before using these models + +# initialize the bedrock aws provider +bedrock = dspy.Bedrock(region_name="us-west-2") +# For mixtral on Bedrock +lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) +# For haiku on Bedrock +lm = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) +# For llama2 on Bedrock +lm = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + +# initialize the sagemaker aws provider +sagemaker = dspy.Sagemaker(region_name="us-west-2") +# For mistral on Sagemaker +# Note: you need to create a Sagemaker endpoint for the mistral model first +lm = dspy.AWSMistral(sagemaker, "", **kwargs) + +``` + +### Constructor + +The `AWSMistral` constructor initializes the base class `AWSModel` which itself inherits from the `LM` class. + +```python +class AWSMistral(AWSModel): + """Mistral family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 32768, + max_new_tokens: int = 1500, + **kwargs + ) -> None: +``` + +**Parameters:** +- `aws_provider` (AWSProvider): The aws provider to use. One of `dspy.Bedrock` or `dspy.Sagemaker`. +- `model` (_str_): Mistral AI pretrained models. For Bedrock, this is the Model ID in https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns. For Sagemaker, this is the endpoint name. +- `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768. 
+- `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500. +- `**kwargs`: Additional language model arguments to pass to the API provider. + +### Methods + +```python +def _format_prompt(self, raw_prompt: str) -> str: +``` +This function formats the prompt for the model. Refer to the model card for the specific formatting required. + +
+ +```python +def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: +``` +This function creates the body of the request to the model. It takes the prompt and any additional keyword arguments and returns a tuple of the number of completions to generate (`n`) and a dictionary, including the prompt, that is used to create the body of the request. + +
+ +```python +def _call_model(self, body: str) -> str: +``` +This function calls the model using the provider `call_model()` function and extracts the generated text (completion) from the provider-specific response. + +
+ +The above model-specific methods are called by the `AWSModel.basic_request()` method, which is the main method for querying the model. This method takes the prompt and any additional keyword arguments and calls `AWSModel._simple_api_call()`, which then delegates to the model-specific `_create_body()` and `_call_model()` methods to create the body of the request, call the model, and extract the generated text. + + +Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation for information on the `LM` base class functionality. + +
+ +`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`. diff --git a/docs/api/language_model_clients/aws_providers.md b/docs/api/language_model_clients/aws_providers.md new file mode 100644 index 0000000000..46950db9e0 --- /dev/null +++ b/docs/api/language_model_clients/aws_providers.md @@ -0,0 +1,53 @@ +--- +sidebar_position: 9 +--- + +# dspy.Bedrock, dspy.Sagemaker + +### Usage + +The `AWSProvider` class is the base class for the AWS providers - `dspy.Bedrock` and `dspy.Sagemaker`. An instance of one of these providers is passed to the constructor when creating an instance of an AWS model class (e.g., `dspy.AWSMistral`) that is ultimately used to query the model. + +```python +# Notes: +# 1. Install boto3 to use AWS models. +# 2. Configure your AWS credentials with the AWS CLI before using these models + +# initialize the bedrock aws provider +bedrock = dspy.Bedrock(region_name="us-west-2") + +# initialize the sagemaker aws provider +sagemaker = dspy.Sagemaker(region_name="us-west-2") +``` + +### Constructor + +The `Bedrock` constructor initializes the base class `AWSProvider`. + +```python +class Bedrock(AWSProvider): + """This class adds support for Bedrock models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = False, # This has to be setup manually on Bedrock. + ) -> None: +``` + +**Parameters:** +- `region_name` (str): The AWS region where this LM is hosted. +- `profile_name` (str, optional): boto3 credentials profile. +- `batch_n_enabled` (bool): If False, call the LM N times rather than batching. + +### Methods + +```python +def call_model(self, model_id: str, body: str) -> str: +``` +This function implements the actual invocation of the model on AWS using the boto3 provider. + +
+ +`Sagemaker` works exactly the same as `Bedrock`. \ No newline at end of file diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index b7a739f504..935f65b403 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -1,6 +1,10 @@ from .anthropic import Claude +from .aws_models import AWSAnthropic, AWSMeta, AWSMistral, AWSModel + +# Below is obsolete. It has been replaced with Bedrock class in dsp/modules/aws_providers.py +# from .bedrock import * +from .aws_providers import Bedrock, Sagemaker from .azure_openai import AzureOpenAI -from .bedrock import * from .cache_utils import * from .clarifai import * from .cohere import * @@ -17,4 +21,3 @@ from .pyserini import * from .sbert import * from .sentence_vectorizer import * - diff --git a/dsp/modules/aws_lm.py b/dsp/modules/aws_lm.py deleted file mode 100644 index 366d0d1ed7..0000000000 --- a/dsp/modules/aws_lm.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -A generalized AWS LLM. -""" - -from __future__ import annotations - -import json -import logging -from abc import abstractmethod -from typing import Any, Literal, Optional - -from dsp.modules.lm import LM - -# Heuristic translating number of chars to tokens -# ~4 chars = 1 token -CHARS2TOKENS: int = 4 - - -class AWSLM(LM): - """ - This class adds support for an AWS model - """ - - def __init__( - self, - model: str, - region_name: str, - service_name: str, - max_new_tokens: int, - profile_name: Optional[str] = None, - truncate_long_prompts: bool = False, - input_output_ratio: int = 3, - batch_n: bool = True, - ) -> None: - """_summary_ - - Args: - - service_name (str): Used in context of invoking the boto3 API. - region_name (str, optional): The AWS region where this LM is hosted. - model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint. - max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. 
- input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3. - temperature (float, optional): _description_. Defaults to 0.0. - truncate_long_prompts (bool, optional): If True, remove extremely long inputs to context. Defaults to False. - batch_n (bool, False): If False, call the LM N times rather than batching. Not all AWS models support the n parameter. - """ - super().__init__(model=model) - # AWS doesn't have an equivalent of max_tokens so let's clarify - # that the expected input is going to be about 2x as long as the output - self.kwargs["max_tokens"] = max_new_tokens * input_output_ratio - self._max_new_tokens: int = max_new_tokens - self._model_name: str = model - self._truncate_long_prompt_prompts: bool = truncate_long_prompts - self._batch_n: bool = batch_n - - import boto3 - - if profile_name is None: - self.predictor = boto3.client(service_name, region_name=region_name) - else: - self.predictor = boto3.Session(profile_name=profile_name).client( - service_name, region_name=region_name, - ) - - @abstractmethod - def _create_body(self, prompt: str, **kwargs): - pass - - def _sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> dict[str, Any]: - """Ensure that input kwargs can be used by Bedrock or Sagemaker.""" - base_args: dict[str, Any] = {"temperature": self.kwargs["temperature"]} - - for k, v in base_args.items(): - if k not in query_kwargs: - query_kwargs[k] = v - if query_kwargs["temperature"] > 1.0: - query_kwargs["temperature"] = 0.99 - if query_kwargs["temperature"] < 0.01: - query_kwargs["temperature"] = 0.01 - - return query_kwargs - - @abstractmethod - def _call_model(self, body: str) -> str | list[str]: - """Call model, get generated input without the formatted prompt""" - pass - - @abstractmethod - def _extract_input_parameters( - self, body: dict[Any, Any], - ) -> dict[str, str | float | int]: - pass - - def _simple_api_call(self, formatted_prompt: str, **kwargs) -> 
str | list[str]: - body = self._create_body(formatted_prompt, **kwargs) - json_body = json.dumps(body) - llm_out: str | list[str] = self._call_model(json_body) - if isinstance(llm_out, str): - llm_out = llm_out.replace(formatted_prompt, "") - else: - llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out] - self.history.append( - {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}, - ) - return llm_out - - def basic_request(self, prompt, **kwargs) -> str | list[str]: - """Query the endpoint.""" - - # Remove any texts that are too long - formatted_prompt: str - if self._truncate_long_prompt_prompts: - truncated_prompt: str = self._truncate_prompt(prompt) - formatted_prompt = self._format_prompt(truncated_prompt) - else: - formatted_prompt = self._format_prompt(prompt) - - llm_out: str | list[str] - if "n" in kwargs.keys(): - if self._batch_n: - llm_out = self._simple_api_call( - formatted_prompt=formatted_prompt, **kwargs, - ) - else: - del kwargs["n"] - llm_out = [] - for _ in range(0, kwargs["n"]): - generated: str | list[str] = self._simple_api_call( - formatted_prompt=formatted_prompt, **kwargs, - ) - if isinstance(generated, str): - llm_out.append(generated) - else: - raise TypeError("Error, list type was returned from LM call") - else: - llm_out = self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs) - - return llm_out - - def _estimate_tokens(self, text: str) -> int: - return len(text) * CHARS2TOKENS - - @abstractmethod - def _format_prompt(self, raw_prompt: str) -> str: - pass - - def _truncate_prompt( - self, - input_text: str, - remove_beginning_or_ending: Literal["beginning", "ending"] = "beginning", - max_input_tokens: int = 2500, - ) -> str: - """Reformat inputs such that they do not overflow context size limitation.""" - token_count = self._estimate_tokens(input_text) - if token_count > self.kwargs["max_tokens"]: - logging.info("Excessive prompt found in llm input") - logging.info("Truncating texts to 
avoid error") - max_chars: int = CHARS2TOKENS * max_input_tokens - truncated_text: str - if remove_beginning_or_ending == "ending": - truncated_text = input_text[0:max_chars] - else: - truncated_text = input_text[-max_chars:] - return truncated_text - return input_text - - def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, - ) -> list[str]: - """ - Query the AWS LLM. - - There is only support for only_completed=True and return_sorted=False - right now. - """ - if not only_completed: - raise NotImplementedError("Error, only_completed not yet supported!") - if return_sorted: - raise NotImplementedError("Error, return_sorted not yet supported!") - generated = self.basic_request(prompt, **kwargs) - return [generated] diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py new file mode 100644 index 0000000000..94b07cad43 --- /dev/null +++ b/dsp/modules/aws_models.py @@ -0,0 +1,304 @@ +"""AWS models for LMs.""" + +from __future__ import annotations + +import json +import logging +from abc import abstractmethod +from typing import Any + +from dsp.modules.aws_providers import AWSProvider, Bedrock, Sagemaker +from dsp.modules.lm import LM + +# Heuristic translating number of chars to tokens +# ~4 chars = 1 token +CHARS2TOKENS: int = 4 + + +class AWSModel(LM): + """This class adds support for an AWS model. + It is an abstract class and should not be instantiated directly. + Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta. + The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker. 
+ Usage Example: + bedrock = dspy.Bedrock(region_name="us-west-2") + bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + + sagemaker = dspy.Sagemaker(region_name="us-west-2") + sagemaker_model = dspy.AWSMistral(sagemaker, "", **kwargs) + """ + + def __init__( + self, + model: str, + max_context_size: int, + max_new_tokens: int, + **kwargs, + ) -> None: + """_summary_. + + Args: + model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint. + max_context_size (int): The maximum context size in tokens. + max_new_tokens (int): The maximum number of tokens to be sampled from the LM. + """ + super().__init__(model=model) + self._model_name: str = model + self._max_context_size: int = max_context_size + self._max_new_tokens: int = max_new_tokens + + self.kwargs = { + **self.kwargs, + **kwargs, + } + + @abstractmethod + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + pass + + @abstractmethod + def _call_model(self, body: str) -> str | list[str]: + """Call model, get generated input without the formatted prompt.""" + + def _estimate_tokens(self, text: str) -> int: + return len(text)/CHARS2TOKENS + + def _extract_input_parameters( + self, + body: dict[Any, Any], + ) -> dict[str, str | float | int]: + return body + + def _format_prompt(self, raw_prompt: str) -> str: + return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" + + def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]: + n, body = self._create_body(formatted_prompt, **kwargs) + json_body = json.dumps(body) + + if n > 1: + llm_out = [self._call_model(json_body) for _ in range(n)] + llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out] + else: + llm_out = self._call_model(json_body) + llm_out = 
llm_out.replace(formatted_prompt, "") + + self.history.append( + {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}, + ) + return llm_out + + def basic_request(self, prompt, **kwargs) -> str | list[str]: + """Query the endpoint.""" + token_count = self._estimate_tokens(prompt) + if token_count > self._max_context_size: + logging.info("Error - input tokens %s exceeds max context %s", token_count, self._max_context_size) + raise ValueError( + f"Error - input tokens {token_count} exceeds max context {self._max_context_size}", + ) + + formatted_prompt: str = self._format_prompt(prompt) + return self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs) + + def __call__( + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, + ) -> list[str]: + """Query the AWS LLM. + + There is only support for only_completed=True and return_sorted=False + right now. + """ + assert only_completed, "for now" + assert return_sorted is False, "for now" + + generated = self.basic_request(prompt, **kwargs) + return [generated] + + +class AWSMistral(AWSModel): + """Mistral family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 32768, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() + + self.kwargs["stop"] = "\n\n---" + + def _format_prompt(self, raw_prompt: str) -> str: + return " [INST] Human: " + raw_prompt + " [/INST] Assistant: " + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, base_args = self.aws_provider.sanitize_kwargs(base_args) 
+ + query_args: dict[str, str | float] = {} + if isinstance(self.aws_provider, Bedrock): + query_args["prompt"] = prompt + elif isinstance(self.aws_provider, Sagemaker): + query_args["parameters"] = base_args + query_args["inputs"] = prompt + else: + raise ValueError("Error - provider not recognized") + + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.aws_provider.call_model( + model_id=self._model_name, + body=body, + ) + if isinstance(self.aws_provider, Bedrock): + response_body = json.loads(response["body"].read()) + completion = response_body["outputs"][0]["text"] + elif isinstance(self.aws_provider, Sagemaker): + response_body = json.loads(response["Body"].read()) + completion = response_body[0]["generated_text"] + else: + raise ValueError("Error - provider not recognized") + + completion = completion.split(self.kwargs["stop"])[0] + return completion + + +class AWSAnthropic(AWSModel): + """Anthropic family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 200000, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() + + if isinstance(self.aws_provider, Bedrock): + self.kwargs["anthropic_version"] = "bedrock-2023-05-31" + + for k, v in kwargs.items(): + self.kwargs[k] = v + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, query_args = self.aws_provider.sanitize_kwargs(base_args) + + # Anthropic models do not support the following parameters + query_args.pop("frequency_penalty", None) + query_args.pop("num_generations", None) + 
query_args.pop("presence_penalty", None) + query_args.pop("model", None) + + # we are using the Claude messages API + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + query_args["messages"] = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + ], + }, + ] + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.aws_provider.predictor.invoke_model( + modelId=self._model_name, + body=body, + ) + response_body = json.loads(response["body"].read()) + completion = response_body["content"][0]["text"] + return completion + + +class AWSMeta(AWSModel): + """Llama2 family of models.""" + + def __init__( + self, + aws_provider: AWSProvider, + model: str, + max_context_size: int = 4096, + max_new_tokens: int = 1500, + **kwargs, + ) -> None: + """NOTE: Configure your AWS credentials with the AWS CLI before using this model!""" + super().__init__( + model=model, + max_context_size=max_context_size, + max_new_tokens=max_new_tokens, + **kwargs, + ) + self.aws_provider = aws_provider + self.provider = aws_provider.get_provider_name() + + for k, v in kwargs.items(): + self.kwargs[k] = v + + self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens") + + def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: + base_args: dict[str, Any] = self.kwargs + for k, v in kwargs.items(): + base_args[k] = v + + n, query_args = self.aws_provider.sanitize_kwargs(base_args) + + # Meta models do not support the following parameters + query_args.pop("frequency_penalty", None) + query_args.pop("num_generations", None) + query_args.pop("presence_penalty", None) + query_args.pop("model", None) + + query_args["prompt"] = prompt + return (n, query_args) + + def _call_model(self, body: str) -> str: + response = self.aws_provider.predictor.invoke_model( + modelId=self._model_name, + body=body, + ) + response_body = json.loads(response["body"].read()) + 
completion = response_body["generation"] + + stop = "\n\n" + completion = completion.split(stop)[0] + + return completion diff --git a/dsp/modules/aws_providers.py b/dsp/modules/aws_providers.py new file mode 100644 index 0000000000..646cc0f431 --- /dev/null +++ b/dsp/modules/aws_providers.py @@ -0,0 +1,128 @@ +"""AWS providers for LMs.""" + +from abc import ABC, abstractmethod +from typing import Any, Optional + + +class AWSProvider(ABC): + """This abstract class adds support for AWS model providers such as Bedrock and SageMaker. + The subclasses such as Bedrock and Sagemaker implement the abstract method _call_model and work in conjunction with the AWSModel classes. + Usage Example: + bedrock = dspy.Bedrock(region_name="us-west-2") + bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs) + bedrock_haiku = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs) + bedrock_llama2 = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs) + + sagemaker = dspy.Sagemaker(region_name="us-west-2") + sagemaker_model = dspy.AWSMistral(sagemaker, "", **kwargs) + """ + + def __init__( + self, + region_name: str, + service_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = True, + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + service_name (str): Used in context of invoking the boto3 API. + profile_name (str, optional): boto3 credentials profile. + batch_n_enabled (bool): If False, call the LM N times rather than batching. 
+ """ + try: + import boto3 + except ImportError as exc: + raise ImportError('pip install boto3 to use AWS models.') from exc + + if profile_name is None: + self.predictor = boto3.client(service_name, region_name=region_name) + else: + self.predictor = boto3.Session(profile_name=profile_name).client( + service_name, + region_name=region_name, + ) + + self.batch_n_enabled = batch_n_enabled + + def get_provider_name(self) -> str: + """Return the provider name.""" + return self.__class__.__name__ + + @abstractmethod + def call_model(self, model_id: str, body: str) -> str: + """Call the model and return the response.""" + + def sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> tuple[int, dict[str, Any]]: + """Ensure that input kwargs can be used by Bedrock or Sagemaker.""" + if "temperature" in query_kwargs: + if query_kwargs["temperature"] > 0.99: + query_kwargs["temperature"] = 0.99 + if query_kwargs["temperature"] < 0.01: + query_kwargs["temperature"] = 0.01 + + if "top_p" in query_kwargs: + if query_kwargs["top_p"] > 0.99: + query_kwargs["top_p"] = 0.99 + if query_kwargs["top_p"] < 0.01: + query_kwargs["top_p"] = 0.01 + + n = -1 + if not self.batch_n_enabled: + n = query_kwargs.pop('n', 1) + query_kwargs["num_generations"] = n + + return n, query_kwargs + + +class Bedrock(AWSProvider): + """This class adds support for Bedrock models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + batch_n_enabled: bool = False, # This has to be setup manually on Bedrock. + ) -> None: + """_summary_. + + Args: + region_name (str): The AWS region where this LM is hosted. + profile_name (str, optional): boto3 credentials profile. 
+ """ + super().__init__(region_name, "bedrock-runtime", profile_name, batch_n_enabled) + + def call_model(self, model_id: str, body: str) -> str: + return self.predictor.invoke_model( + modelId=model_id, + body=body, + accept="application/json", + contentType="application/json", + ) + + +class Sagemaker(AWSProvider): + """This class adds support for Sagemaker models.""" + + def __init__( + self, + region_name: str, + profile_name: Optional[str] = None, + ) -> None: + """_summary_. + + Args: + region_name (str, optional): The AWS region where this LM is hosted. + profile_name (str, optional): boto3 credentials profile. + """ + super().__init__(region_name, "runtime.sagemaker", profile_name) + + def call_model(self, model_id: str, body: str) -> str: + return self.predictor.invoke_endpoint( + EndpointName=model_id, + Body=body, + Accept="application/json", + ContentType="application/json", + ) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py deleted file mode 100644 index 252c87fe86..0000000000 --- a/dsp/modules/bedrock.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, Optional - -from dsp.modules.aws_lm import AWSLM - - -class Bedrock(AWSLM): - def __init__( - self, - region_name: str, - model: str, - profile_name: Optional[str] = None, - input_output_ratio: int = 3, - max_new_tokens: int = 1500, - ) -> None: - """Use an AWS Bedrock language model. - NOTE: You must first configure your AWS credentials with the AWS CLI before using this model! - - Args: - region_name (str, optional): The AWS region where this LM is hosted. - model (str, optional): An AWS Bedrock LM name. You can find available models with the AWS CLI as follows: aws bedrock list-foundation-models --query "modelSummaries[*].modelId". - temperature (float, optional): Default temperature for LM. Defaults to 0. - input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. 
Defaults to 3. - max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. - """ - super().__init__( - model=model, - service_name="bedrock-runtime", - region_name=region_name, - profile_name=profile_name, - truncate_long_prompts=False, - input_output_ratio=input_output_ratio, - max_new_tokens=max_new_tokens, - batch_n=True, # Bedrock does not support the `n` parameter - ) - self._validate_model(model) - self.provider = "claude" if "claude" in model.lower() else "bedrock" - - def _validate_model(self, model: str) -> None: - if "claude" not in model.lower(): - raise NotImplementedError("Only claude models are supported as of now") - - def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]: - base_args: dict[str, Any] = { - "max_tokens_to_sample": self._max_new_tokens, - } - for k, v in kwargs.items(): - base_args[k] = v - query_args: dict[str, Any] = self._sanitize_kwargs(base_args) - query_args["prompt"] = prompt - # AWS Bedrock forbids these keys - if "max_tokens" in query_args: - max_tokens: int = query_args["max_tokens"] - input_tokens: int = self._estimate_tokens(prompt) - max_tokens_to_sample: int = max_tokens - input_tokens - del query_args["max_tokens"] - query_args["max_tokens_to_sample"] = max_tokens_to_sample - return query_args - - def _call_model(self, body: str) -> str: - response = self.predictor.invoke_model( - modelId=self._model_name, - body=body, - accept="application/json", - contentType="application/json", - ) - response_body = json.loads(response["body"].read()) - completion = response_body["completion"] - return completion - - def _extract_input_parameters( - self, body: dict[Any, Any], - ) -> dict[str, str | float | int]: - return body - - def _format_prompt(self, raw_prompt: str) -> str: - return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index 9f08439107..b1c26dc8da 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -46,7 +46,7 @@ def 
inspect_history(self, n: int = 1, skip: int = 0): prompt = x["prompt"] if prompt != last_prompt: - if provider == "clarifai" or provider == "google" or provider == "groq": + if provider == "clarifai" or provider == "google" or provider == "groq" or provider == "Bedrock" or provider == "Sagemaker": printed.append((prompt, x["response"])) elif provider == "anthropic": blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"] @@ -74,7 +74,7 @@ def inspect_history(self, n: int = 1, skip: int = 0): printing_value += prompt text = "" - if provider == "cohere": + if provider == "cohere" or provider == "Bedrock" or provider == "Sagemaker": text = choices elif provider == "openai" or provider == "ollama": text = ' ' + self._get_choice_text(choices[0]).strip() diff --git a/dspy/__init__.py b/dspy/__init__.py index 1c6539f59b..2d5ba4f003 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -33,7 +33,13 @@ Together = dsp.Together HFModel = dsp.HFModel OllamaLocal = dsp.OllamaLocal + Bedrock = dsp.Bedrock +Sagemaker = dsp.Sagemaker +AWSModel = dsp.AWSModel +AWSMistral = dsp.AWSMistral +AWSAnthropic = dsp.AWSAnthropic +AWSMeta = dsp.AWSMeta configure = settings.configure context = settings.context diff --git a/poetry.lock b/poetry.lock index 62b3b37ad3..d92a6a00e4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -530,6 +530,47 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "boto3" +version = "1.34.81" +description = "The AWS SDK for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.81-py3-none-any.whl", hash = "sha256:18224d206a8a775bcaa562d22ed3d07854934699190e12b52fcde87aac76a80e"}, + {file = "boto3-1.34.81.tar.gz", hash = "sha256:004dad209d37b3d2df88f41da13b7ad702a751904a335fac095897ff7a19f82b"}, +] + +[package.dependencies] +botocore = ">=1.34.81,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.81" +description = "Low-level, data-driven core of boto 3." +optional = true +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.81-py3-none-any.whl", hash = "sha256:85f6fd7c5715eeef7a236c50947de00f57d72e7439daed1125491014b70fab01"}, + {file = "botocore-1.34.81.tar.gz", hash = "sha256:f79bf122566cc1f09d71cc9ac9fcf52d47ba48b761cbc3f064017b36a3c40eb8"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.19.19)"] + [[package]] name = "cachetools" version = "5.3.3" @@ -2073,6 +2114,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = 
"sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "joblib" version = "1.3.2" @@ -4330,7 +4382,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4829,6 +4880,23 @@ files = [ {file = "ruff-0.3.4.tar.gz", hash = "sha256:f0f4484c6541a99862b693e13a151435a279b271cff20e37101116a21e2a1ad1"}, ] +[[package]] +name = "s3transfer" +version = "0.10.1" +description = "An Amazon S3 Transfer Manager" +optional = true +python-versions = ">= 3.8" +files = [ + {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, + {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "safetensors" version = "0.4.2" @@ -5904,6 +5972,22 @@ files = [ {file = "ujson-5.9.0.tar.gz", hash = 
"sha256:89cc92e73d5501b8a7f48575eeb14ad27156ad092c2e9fc7e3cf949f07e75532"}, ] +[[package]] +name = "urllib3" +version = "1.26.18" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, + {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.1" @@ -6633,6 +6717,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] +aws = ["boto3"] chromadb = ["chromadb"] docs = ["autodoc_pydantic", "docutils", "furo", "m2r2", "myst-nb", "myst-parser", "sphinx", "sphinx-autobuild", "sphinx-automodapi", "sphinx-reredirects", "sphinx_rtd_theme"] marqo = ["marqo"] @@ -6645,4 +6730,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "582b069b4377e2b568dfdc2b43187f7c5cb3af12e1a4d6fc3a6fa7dde7edba10" +content-hash = "3988e0bedd832c87fda15828d8c6f08b2c3a9e75a9bca6d4201c5b8bdf5e3c9e" diff --git a/pyproject.toml b/pyproject.toml index 10274a71e3..f99df941a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ marqo = ["marqo"] pinecone = ["pinecone-client~=2.2.4"] weaviate = 
["weaviate-client~=4.5.4"] milvus = ["pymilvus~=2.3.7"] +aws = ["boto3~=1.34.78"] docs = [ "sphinx>=4.3.0", "furo>=2023.3.27", @@ -59,7 +60,6 @@ docs = [ "autodoc_pydantic", "sphinx-reredirects>=0.1.2", "sphinx-automodapi==0.16.0", - ] dev = ["pytest>=6.2.5"] @@ -101,6 +101,7 @@ qdrant-client = { version = "^1.6.2", optional = true } pinecone-client = { version = "^2.2.4", optional = true } weaviate-client = { version = "^4.5.4", optional = true } pymilvus = { version = "^2.3.6", optional = true } +boto3 = { version = "^1.34.78", optional = true } sphinx = { version = ">=4.3.0", optional = true } furo = { version = ">=2023.3.27", optional = true } docutils = { version = "<0.17", optional = true } @@ -135,6 +136,7 @@ marqo = ["marqo"] pinecone = ["pinecone-client"] weaviate = ["weaviate-client"] milvus = ["pymilvus"] +aws = ["boto3"] postgres = ["psycopg2", "pgvector"] docs = [ "sphinx", diff --git a/tests/modules/test_aws_models.py b/tests/modules/test_aws_models.py new file mode 100644 index 0000000000..fd0794f7cd --- /dev/null +++ b/tests/modules/test_aws_models.py @@ -0,0 +1,62 @@ +"""Tests for AWS models. +Note: Requires configuration of your AWS credentials with the AWS CLI and creating sagemaker endpoints. +TODO: Create mock fixtures for pytest to remove the need for AWS credentials and endpoints. 
# """
# NOTE: the marker on the line above terminates the module docstring that
# opens on the preceding line of the file; everything below is module code.

import dsp
import dspy

# AWS region used for every provider client created by this module.
_AWS_REGION = "us-west-2"

# Expected shape of the ``model_path`` argument accepted by ``get_lm``.
_MODEL_PATH_FORMAT = "<MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>"


def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM:
    """Build the dspy language-model client for a provider and model path.

    Args:
        lm_provider: Either ``'Bedrock'`` or ``'Sagemaker'``.
        model_path: String of the form ``<MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>``.
            The part before the first ``/`` selects the model class
            (``mistral``, ``anthropic`` or ``meta`` on Bedrock; ``mistral`` or
            ``meta`` on Sagemaker). Everything after the first ``/`` is passed
            to that class unchanged as the model name.
        **kwargs: Forwarded verbatim to the selected model class.

    Returns:
        An instance of ``dspy.AWSMistral``, ``dspy.AWSAnthropic`` or
        ``dspy.AWSMeta`` bound to a freshly created provider client.

    Raises:
        ValueError: If ``model_path`` has no vendor or no model name, if the
            provider is not supported, or if the vendor is not supported by
            the chosen provider.
    """
    # Which model classes each provider can serve, keyed by vendor prefix.
    supported = {
        "Bedrock": (
            dspy.Bedrock,
            {
                "mistral": dspy.AWSMistral,
                "anthropic": dspy.AWSAnthropic,
                "meta": dspy.AWSMeta,
            },
        ),
        "Sagemaker": (
            dspy.Sagemaker,
            {
                "mistral": dspy.AWSMistral,
                "meta": dspy.AWSMeta,
            },
        ),
    }

    # partition() never raises on a missing "/", and it keeps any further
    # "/" characters inside the model name instead of truncating it.
    model_vendor, separator, model_name = model_path.partition("/")
    if not separator or not model_vendor:
        raise ValueError(
            f"Model vendor missing or unsupported: Model path format is {_MODEL_PATH_FORMAT}"
        )
    if not model_name:
        raise ValueError(
            f"Model name or endpoint missing: Model path format is {_MODEL_PATH_FORMAT}"
        )

    try:
        provider_cls, model_classes = supported[lm_provider]
    except KeyError:
        raise ValueError(f"Unsupported provider: {lm_provider}") from None

    model_cls = model_classes.get(model_vendor)
    if model_cls is None:
        raise ValueError(
            f"Model vendor missing or unsupported: Model path format is {_MODEL_PATH_FORMAT}"
        )

    # Create the AWS client only after all arguments have been validated, so
    # that a bad request fails without needing credentials.
    provider = provider_cls(region_name=_AWS_REGION)
    return model_cls(provider, model_name, **kwargs)


def run_tests():
    """Ask every configured provider/model pair one question and print the result.

    This calls the real AWS services: credentials must already be configured
    with the AWS CLI. It is meant to be executed as a script (see the
    ``__main__`` guard at the bottom of the module).
    """
    provider_model_tuples = [
        ('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'),
        ('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'),
        ('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'),
        ('Bedrock', 'meta/meta.llama2-70b-chat-v1'),
        # ('Sagemaker', 'mistral/<YOUR_ENDPOINT_NAME>'),  # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint
    ]

    predict_func = dspy.Predict("question -> answer")
    for provider, model_path in provider_model_tuples:
        print(f"Provider: {provider}, Model: {model_path}")
        lm = get_lm(provider, model_path)
        # Scope the model to this iteration only, so each pair is exercised
        # with its own client.
        with dspy.context(lm=lm):
            question = "What is the capital of France?"
            answer = predict_func(question=question).answer
            print(f"Question: {question}\nAnswer: {answer}")
            print("---------------------------------")
            lm.inspect_history()
            print("---------------------------------\n")


if __name__ == "__main__":
    run_tests()