39 changes: 23 additions & 16 deletions dsp/modules/aws_models.py
@@ -17,9 +17,11 @@
 
 class AWSModel(LM):
     """This class adds support for an AWS model.
+
     It is an abstract class and should not be instantiated directly.
     Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta.
-    The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
+    The subclasses implement the abstract methods _create_body and _call_model
+    and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
     Usage Example:
         bedrock = dspy.Bedrock(region_name="us-west-2")
         bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
@@ -43,6 +45,7 @@ def __init__(
             model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint.
             max_context_size (int): The maximum context size in tokens.
             max_new_tokens (int): The maximum number of tokens to be sampled from the LM.
+            **kwargs: Additional arguments.
         """
         super().__init__(model=model)
         self._model_name: str = model
@@ -117,8 +120,10 @@ def __call__(
         There is only support for only_completed=True and return_sorted=False
         right now.
         """
-        assert only_completed, "for now"
-        assert return_sorted is False, "for now"
+        if not only_completed:
+            raise ValueError("only_completed must be True for now")
+        if return_sorted:
+            raise ValueError("return_sorted must be False for now")
 
         generated = self.basic_request(prompt, **kwargs)
         return [generated]
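
Replacing the bare asserts with explicit ValueError raises keeps the guard active under `python -O` (which strips assert statements) and gives callers an actionable message. A minimal sketch of the new failure mode, reusing the Mixtral model ID from the test file (assumes AWS credentials are configured):

import dspy

bedrock = dspy.Bedrock(region_name="us-west-2")
lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1")

try:
    lm("Hello", return_sorted=True)  # unsupported option
except ValueError as err:
    print(err)  # -> return_sorted must be False for now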
@@ -182,8 +187,7 @@ def _call_model(self, body: str) -> str:
         else:
             raise ValueError("Error - provider not recognized")
 
-        completion = completion.split(self.kwargs["stop"])[0]
-        return completion
+        return completion.split(self.kwargs["stop"])[0]
 
 
 class AWSAnthropic(AWSModel):
@@ -247,12 +251,11 @@ def _call_model(self, body: str) -> str:
             body=body,
         )
         response_body = json.loads(response["body"].read())
-        completion = response_body["content"][0]["text"]
-        return completion
+        return response_body["content"][0]["text"]
 
 
 class AWSMeta(AWSModel):
-    """Llama2 family of models."""
+    """Llama3 family of models."""
 
     def __init__(
         self,
@@ -275,10 +278,15 @@ def __init__(
         for k, v in kwargs.items():
             self.kwargs[k] = v
 
-        self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens")
+    def _format_prompt(self, raw_prompt: str) -> str:
+        return (
+            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
+            + raw_prompt
+            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+        )
 
     def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
-        base_args: dict[str, Any] = self.kwargs
+        base_args: dict[str, Any] = self.kwargs.copy()
         for k, v in kwargs.items():
             base_args[k] = v
 
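The switch to `self.kwargs.copy()` fixes an aliasing bug: previously `base_args` was a reference to `self.kwargs`, so per-call overrides mutated the model's shared configuration and leaked into every subsequent request. A self-contained sketch of the pitfall (dictionary names are illustrative):

shared_config = {"temperature": 0.0}

base_args = shared_config            # old behavior: alias, not a copy
base_args["temperature"] = 0.9
print(shared_config["temperature"])  # 0.9 -- shared config was mutated

shared_config = {"temperature": 0.0}
base_args = shared_config.copy()     # new behavior: per-call copy
base_args["temperature"] = 0.9
print(shared_config["temperature"])  # 0.0 -- shared config untouched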
@@ -290,6 +298,10 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
         query_args.pop("presence_penalty", None)
         query_args.pop("model", None)
 
+        max_tokens = query_args.pop("max_tokens", None)
+        if max_tokens:
+            query_args["max_gen_len"] = max_tokens
+
         query_args["prompt"] = prompt
         return (n, query_args)
 
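Moving the `max_tokens` -> `max_gen_len` remap out of `__init__` and into `_create_body` sets the parameter Bedrock's Meta Llama runtime accepts (`max_gen_len`) per request, and the `pop(..., None)` default avoids a KeyError when no `max_tokens` was supplied. A sketch of the resulting request body (values are illustrative):

import json

query_args = {"temperature": 0.7, "max_tokens": 256}

max_tokens = query_args.pop("max_tokens", None)
if max_tokens:
    query_args["max_gen_len"] = max_tokens
query_args["prompt"] = "What is 2 + 2?"

print(json.dumps(query_args))
# {"temperature": 0.7, "max_gen_len": 256, "prompt": "What is 2 + 2?"}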
@@ -299,9 +311,4 @@ def _call_model(self, body: str) -> str:
             body=body,
         )
         response_body = json.loads(response["body"].read())
-        completion = response_body["generation"]
-
-        stop = "\n\n"
-        completion = completion.split(stop)[0]
-
-        return completion
+        return response_body["generation"]
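
Dropping the hard-coded "\n\n" stop-split also means multi-paragraph Llama completions are no longer silently truncated. Taken together, a minimal end-to-end sketch of the new Llama 3 path (assumes AWS credentials are configured and the model is enabled for your Bedrock account; model ID taken from the test file below):

import dspy

bedrock = dspy.Bedrock(region_name="us-west-2")
llama3 = dspy.AWSMeta(bedrock, "meta.llama3-8b-instruct-v1:0", max_tokens=256)

# __call__ returns a list of completions; this code path yields one.
print(llama3("Briefly explain what Amazon Bedrock is.")[0])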
38 changes: 23 additions & 15 deletions tests/modules/test_aws_models.py
@@ -6,42 +6,50 @@
 import dsp
 import dspy
 
+
 def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM:
     """get the language model"""
     # extract model vendor and name from model name
     # Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>
-    model_vendor = model_path.split('/')[0]
-    model_name = model_path.split('/')[1]
+    model_vendor = model_path.split("/")[0]
+    model_name = model_path.split("/")[1]
 
-    if lm_provider == 'Bedrock':
+    if lm_provider == "Bedrock":
         bedrock = dspy.Bedrock(region_name="us-west-2")
-        if model_vendor == 'mistral':
+        if model_vendor == "mistral":
             return dspy.AWSMistral(bedrock, model_name, **kwargs)
-        elif model_vendor == 'anthropic':
+        elif model_vendor == "anthropic":
             return dspy.AWSAnthropic(bedrock, model_name, **kwargs)
-        elif model_vendor == 'meta':
+        elif model_vendor == "meta":
             return dspy.AWSMeta(bedrock, model_name, **kwargs)
         else:
-            raise ValueError("Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>")
-    elif lm_provider == 'Sagemaker':
+            raise ValueError(
+                "Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>"
+            )
+    elif lm_provider == "Sagemaker":
         sagemaker = dspy.Sagemaker(region_name="us-west-2")
-        if model_vendor == 'mistral':
+        if model_vendor == "mistral":
             return dspy.AWSMistral(sagemaker, model_name, **kwargs)
-        elif model_vendor == 'meta':
+        elif model_vendor == "meta":
             return dspy.AWSMeta(sagemaker, model_name, **kwargs)
         else:
-            raise ValueError("Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>")
+            raise ValueError(
+                "Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>"
+            )
     else:
         raise ValueError(f"Unsupported model: {model_name}")
 
 
 def run_tests():
     """Test the providers and models"""
     # Configure your AWS credentials with the AWS CLI before running this script
     provider_model_tuples = [
-        ('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'),
-        ('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'),
-        ('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'),
-        ('Bedrock', 'meta/meta.llama2-70b-chat-v1'),
+        ("Bedrock", "mistral/mistral.mixtral-8x7b-instruct-v0:1"),
+        ("Bedrock", "anthropic/anthropic.claude-3-haiku-20240307-v1:0"),
+        ("Bedrock", "anthropic/anthropic.claude-3-sonnet-20240229-v1:0"),
+        ("Bedrock", "meta/meta.llama2-70b-chat-v1"),
+        ("Bedrock", "meta/meta.llama3-8b-instruct-v1:0"),
+        ("Bedrock", "meta/meta.llama3-70b-instruct-v1:0"),
         # ('Sagemaker', 'mistral/<YOUR_ENDPOINT_NAME>'), # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint
     ]

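As a quick smoke test of the new entries, get_lm can also be called directly with the <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT> path format the file documents (requires AWS credentials with Bedrock access to the Llama 3 models; output will vary):

llama3 = get_lm("Bedrock", "meta/meta.llama3-8b-instruct-v1:0", max_tokens=100)
print(llama3("Say hello in one sentence.")[0])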