diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py
index 4c8937404a..ba2fa08488 100644
--- a/dsp/modules/aws_models.py
+++ b/dsp/modules/aws_models.py
@@ -17,9 +17,11 @@ class AWSModel(LM):
     """This class adds support for an AWS model.
+
     It is an abstract class and should not be instantiated directly.
     Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta.
-    The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
+    The subclasses implement the abstract methods _create_body and _call_model
+    and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
 
     Usage Example:
         bedrock = dspy.Bedrock(region_name="us-west-2")
         bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
@@ -43,6 +45,7 @@ def __init__(
             model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint.
             max_context_size (int): The maximum context size in tokens.
             max_new_tokens (int): The maximum number of tokens to be sampled from the LM.
+            **kwargs: Additional arguments.
         """
         super().__init__(model=model)
         self._model_name: str = model
@@ -117,8 +120,10 @@ def __call__(
 
         There is only support for only_completed=True and return_sorted=False right now.
         """
-        assert only_completed, "for now"
-        assert return_sorted is False, "for now"
+        if not only_completed:
+            raise ValueError("only_completed must be True for now")
+        if return_sorted:
+            raise ValueError("return_sorted must be False for now")
 
         generated = self.basic_request(prompt, **kwargs)
         return [generated]
@@ -182,8 +187,7 @@ def _call_model(self, body: str) -> str:
         else:
             raise ValueError("Error - provider not recognized")
 
-        completion = completion.split(self.kwargs["stop"])[0]
-        return completion
+        return completion.split(self.kwargs["stop"])[0]
 
 
 class AWSAnthropic(AWSModel):
@@ -247,12 +251,11 @@ def _call_model(self, body: str) -> str:
             body=body,
         )
         response_body = json.loads(response["body"].read())
-        completion = response_body["content"][0]["text"]
-        return completion
+        return response_body["content"][0]["text"]
 
 
 class AWSMeta(AWSModel):
-    """Llama2 family of models."""
+    """Llama3 family of models."""
 
     def __init__(
         self,
@@ -275,10 +278,15 @@ def __init__(
         for k, v in kwargs.items():
             self.kwargs[k] = v
 
-        self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens")
+    def _format_prompt(self, raw_prompt: str) -> str:
+        return (
+            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
+            + raw_prompt
+            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+        )
 
     def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
-        base_args: dict[str, Any] = self.kwargs
+        base_args: dict[str, Any] = self.kwargs.copy()
         for k, v in kwargs.items():
             base_args[k] = v
@@ -290,6 +298,10 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa
         query_args.pop("presence_penalty", None)
         query_args.pop("model", None)
 
+        max_tokens = query_args.pop("max_tokens", None)
+        if max_tokens:
+            query_args["max_gen_len"] = max_tokens
+
         query_args["prompt"] = prompt
         return (n, query_args)
@@ -299,9 +311,4 @@ def _call_model(self, body: str) -> str:
             body=body,
         )
         response_body = json.loads(response["body"].read())
-        completion = response_body["generation"]
-
-        stop = "\n\n"
-        completion = completion.split(stop)[0]
-
-        return completion
+        return response_body["generation"]
diff --git a/tests/modules/test_aws_models.py b/tests/modules/test_aws_models.py
index fd0794f7cd..b6e018b337 100644
--- a/tests/modules/test_aws_models.py
+++ b/tests/modules/test_aws_models.py
@@ -6,42 +6,50 @@
 import dsp
 import dspy
 
+
 def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM:
     """get the language model"""
     # extract model vendor and name from model name
     # Model path format is <model_vendor>/<model_name>
-    model_vendor = model_path.split('/')[0]
-    model_name = model_path.split('/')[1]
+    model_vendor = model_path.split("/")[0]
+    model_name = model_path.split("/")[1]
 
-    if lm_provider == 'Bedrock':
+    if lm_provider == "Bedrock":
         bedrock = dspy.Bedrock(region_name="us-west-2")
-        if model_vendor == 'mistral':
+        if model_vendor == "mistral":
             return dspy.AWSMistral(bedrock, model_name, **kwargs)
-        elif model_vendor == 'anthropic':
+        elif model_vendor == "anthropic":
             return dspy.AWSAnthropic(bedrock, model_name, **kwargs)
-        elif model_vendor == 'meta':
+        elif model_vendor == "meta":
             return dspy.AWSMeta(bedrock, model_name, **kwargs)
         else:
-            raise ValueError("Model vendor missing or unsupported: Model path format is <model_vendor>/<model_name>")
-    elif lm_provider == 'Sagemaker':
+            raise ValueError(
+                "Model vendor missing or unsupported: Model path format is <model_vendor>/<model_name>"
+            )
+    elif lm_provider == "Sagemaker":
         sagemaker = dspy.Sagemaker(region_name="us-west-2")
-        if model_vendor == 'mistral':
+        if model_vendor == "mistral":
             return dspy.AWSMistral(sagemaker, model_name, **kwargs)
-        elif model_vendor == 'meta':
+        elif model_vendor == "meta":
             return dspy.AWSMeta(sagemaker, model_name, **kwargs)
         else:
-            raise ValueError("Model vendor missing or unsupported: Model path format is <model_vendor>/<model_name>")
+            raise ValueError(
+                "Model vendor missing or unsupported: Model path format is <model_vendor>/<model_name>"
+            )
     else:
         raise ValueError(f"Unsupported model: {model_name}")
 
+
 def run_tests():
     """Test the providers and models"""
     # Configure your AWS credentials with the AWS CLI before running this script
     provider_model_tuples = [
-        ('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'),
-        ('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'),
-        ('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'),
-        ('Bedrock', 'meta/meta.llama2-70b-chat-v1'),
+        ("Bedrock", "mistral/mistral.mixtral-8x7b-instruct-v0:1"),
+        ("Bedrock", "anthropic/anthropic.claude-3-haiku-20240307-v1:0"),
+        ("Bedrock", "anthropic/anthropic.claude-3-sonnet-20240229-v1:0"),
+        ("Bedrock", "meta/meta.llama2-70b-chat-v1"),
+        ("Bedrock", "meta/meta.llama3-8b-instruct-v1:0"),
+        ("Bedrock", "meta/meta.llama3-70b-instruct-v1:0"),
         # ('Sagemaker', 'mistral/<YOUR_ENDPOINT_NAME>'),  # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint
     ]
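
Reviewer note (illustrative, not part of the patch): the sketch below is a minimal, AWS-free rendering of the two behavioral changes this patch makes to AWSMeta. The helper names format_llama3_prompt and build_query_args are hypothetical stand-ins for AWSMeta._format_prompt and the argument handling inside AWSMeta._create_body; only the Llama 3 template string and the max_tokens -> max_gen_len mapping are taken from the patch itself.

    from typing import Any


    def format_llama3_prompt(raw_prompt: str) -> str:
        """Stand-in for AWSMeta._format_prompt: wrap a raw prompt in the
        Llama 3 instruct chat-template tokens used by the patch."""
        return (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
            + raw_prompt
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        )


    def build_query_args(base_kwargs: dict[str, Any], **call_kwargs: Any) -> dict[str, Any]:
        """Stand-in for the argument handling in AWSMeta._create_body."""
        # .copy() mirrors the fix above: per-call kwargs no longer mutate the
        # shared self.kwargs dict between calls.
        query_args = base_kwargs.copy()
        query_args.update(call_kwargs)
        # The generic "max_tokens" kwarg is now translated per request into the
        # "max_gen_len" field that Bedrock's Meta models expect, rather than
        # popped from self.kwargs once in __init__, which assumed the key
        # was always present.
        max_tokens = query_args.pop("max_tokens", None)
        if max_tokens:
            query_args["max_gen_len"] = max_tokens
        return query_args


    if __name__ == "__main__":
        shared = {"temperature": 0.0, "max_tokens": 100}
        assert build_query_args(shared) == {"temperature": 0.0, "max_gen_len": 100}
        assert shared == {"temperature": 0.0, "max_tokens": 100}  # caller's dict untouched
        print(format_llama3_prompt("What is the capital of France?"))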